1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2025-01-20 11:37:56 +00:00

Utils: Added HTTP drivers

This commit is contained in:
diamondburned (Forefront) 2020-04-19 14:53:53 -07:00
parent bf93a9cee9
commit 2afe683b7d
11 changed files with 380 additions and 261 deletions

View file

@ -7,6 +7,7 @@ import (
"github.com/diamondburned/arikawa/api/rate"
"github.com/diamondburned/arikawa/utils/httputil"
"github.com/diamondburned/arikawa/utils/httputil/httpdriver"
)
var (
@ -22,39 +23,40 @@ var (
var UserAgent = "DiscordBot (https://github.com/diamondburned/arikawa, v0.0.1)"
type Client struct {
httputil.Client
*httputil.Client
Limiter *rate.Limiter
Token string
Token string
UserAgent string
}
func NewClient(token string) *Client {
return NewCustomClient(token, httputil.NewClient())
}
func NewCustomClient(token string, httpClient *httputil.Client) *Client {
cli := &Client{
Client: httputil.DefaultClient,
Limiter: rate.NewLimiter(APIPath),
Token: token,
Client: httpClient,
Limiter: rate.NewLimiter(APIPath),
Token: token,
UserAgent: UserAgent,
}
tw := httputil.NewTransportWrapper()
tw.Pre = func(r *http.Request) error {
if cli.Token != "" {
r.Header.Set("Authorization", cli.Token)
}
cli.DefaultOptions = []httputil.RequestOption{
func(r httpdriver.Request) error {
r.AddHeader(http.Header{
"Authorization": {cli.Token},
"User-Agent": {cli.UserAgent},
"X-RateLimit-Precision": {"millisecond"},
})
r.Header.Set("User-Agent", UserAgent)
r.Header.Set("X-RateLimit-Precision", "millisecond")
// Rate limit stuff
return cli.Limiter.Acquire(r.Context(), r.URL.Path)
// Rate limit stuff
return cli.Limiter.Acquire(r.GetContext(), r.GetPath())
},
}
tw.Post = func(r *http.Request, resp *http.Response) error {
if resp == nil {
return cli.Limiter.Release(r.URL.Path, nil)
}
return cli.Limiter.Release(r.URL.Path, resp.Header)
cli.OnResponse = func(r httpdriver.Request, resp httpdriver.Response) error {
return cli.Limiter.Release(r.GetPath(), httpdriver.OptHeader(resp))
}
cli.Client.Transport = tw
return cli
}

View file

@ -32,8 +32,7 @@ type CreateGuildData struct {
func (c *Client) CreateGuild(data CreateGuildData) (*discord.Guild, error) {
var g *discord.Guild
return g, c.RequestJSON(&g, "POST", Endpoint+"guilds",
httputil.WithJSONBody(c, data))
return g, c.RequestJSON(&g, "POST", Endpoint+"guilds", httputil.WithJSONBody(c, data))
}
func (c *Client) Guild(guildID discord.Snowflake) (*discord.Guild, error) {
@ -76,23 +75,17 @@ func (c *Client) Guilds(max uint) ([]discord.Guild, error) {
}
// GuildsBefore fetches guilds. Check GuildsRange.
func (c *Client) GuildsBefore(
before discord.Snowflake, limit uint) ([]discord.Guild, error) {
func (c *Client) GuildsBefore(before discord.Snowflake, limit uint) ([]discord.Guild, error) {
return c.GuildsRange(before, 0, limit)
}
// GuildsAfter fetches guilds. Check GuildsRange.
func (c *Client) GuildsAfter(
after discord.Snowflake, limit uint) ([]discord.Guild, error) {
func (c *Client) GuildsAfter(after discord.Snowflake, limit uint) ([]discord.Guild, error) {
return c.GuildsRange(0, after, limit)
}
// GuildsRange fetches guilds. The limit is 1-100.
func (c *Client) GuildsRange(
before, after discord.Snowflake, limit uint) ([]discord.Guild, error) {
func (c *Client) GuildsRange(before, after discord.Snowflake, limit uint) ([]discord.Guild, error) {
if limit == 0 {
limit = 100
}
@ -163,21 +156,15 @@ func (c *Client) DeleteGuild(guildID discord.Snowflake) error {
// GuildVoiceRegions is the same as /voice, but returns VIP ones as well if
// available.
func (c *Client) VoiceRegionsGuild(
guildID discord.Snowflake) ([]discord.VoiceRegion, error) {
func (c *Client) VoiceRegionsGuild(guildID discord.Snowflake) ([]discord.VoiceRegion, error) {
var vrs []discord.VoiceRegion
return vrs, c.RequestJSON(&vrs, "GET",
EndpointGuilds+guildID.String()+"/regions")
return vrs, c.RequestJSON(&vrs, "GET", EndpointGuilds+guildID.String()+"/regions")
}
// Integrations requires MANAGE_GUILD.
func (c *Client) Integrations(
guildID discord.Snowflake) ([]discord.Integration, error) {
func (c *Client) Integrations(guildID discord.Snowflake) ([]discord.Integration, error) {
var ints []discord.Integration
return ints, c.RequestJSON(&ints, "GET",
EndpointGuilds+guildID.String()+"/integrations")
return ints, c.RequestJSON(&ints, "GET", EndpointGuilds+guildID.String()+"/integrations")
}
// AttachIntegration requires MANAGE_GUILD.
@ -214,46 +201,36 @@ func (c *Client) ModifyIntegration(
return c.FastRequest(
"PATCH",
EndpointGuilds+guildID.String()+
"/integrations/"+integrationID.String(),
EndpointGuilds+guildID.String()+"/integrations/"+integrationID.String(),
httputil.WithSchema(c, param),
)
}
func (c *Client) SyncIntegration(
guildID, integrationID discord.Snowflake) error {
func (c *Client) SyncIntegration(guildID, integrationID discord.Snowflake) error {
return c.FastRequest("POST", EndpointGuilds+guildID.String()+
"/integrations/"+integrationID.String()+"/sync")
}
func (c *Client) GuildEmbed(
guildID discord.Snowflake) (*discord.GuildEmbed, error) {
func (c *Client) GuildEmbed(guildID discord.Snowflake) (*discord.GuildEmbed, error) {
var ge *discord.GuildEmbed
return ge, c.RequestJSON(&ge, "GET",
EndpointGuilds+guildID.String()+"/embed")
return ge, c.RequestJSON(&ge, "GET", EndpointGuilds+guildID.String()+"/embed")
}
// ModifyGuildEmbed should be used with care: if you still want the embed
// enabled, you need to set the Enabled boolean, even if it's already enabled.
// If you don't, JSON will default it to false.
func (c *Client) ModifyGuildEmbed(
guildID discord.Snowflake,
data discord.GuildEmbed) (*discord.GuildEmbed, error) {
return &data, c.RequestJSON(&data, "PATCH",
EndpointGuilds+guildID.String()+"/embed")
// ModifyGuildEmbed modifies the guild embed and updates the passed in
// GuildEmbed data.
//
// This method should be used with care: if you still want the embed enabled,
// you need to set the Enabled boolean, even if it's already enabled. If you
// don't, JSON will default it to false.
func (c *Client) ModifyGuildEmbed(guildID discord.Snowflake, data *discord.GuildEmbed) error {
return c.RequestJSON(&data, "PATCH", EndpointGuilds+guildID.String()+"/embed")
}
// GuildVanityURL returns *Invite, but only Code and Uses are filled. Requires
// MANAGE_GUILD.
func (c *Client) GuildVanityURL(
guildID discord.Snowflake) (*discord.Invite, error) {
func (c *Client) GuildVanityURL(guildID discord.Snowflake) (*discord.Invite, error) {
var inv *discord.Invite
return inv, c.RequestJSON(&inv, "GET",
EndpointGuilds+guildID.String()+"/vanity-url")
return inv, c.RequestJSON(&inv, "GET", EndpointGuilds+guildID.String()+"/vanity-url")
}
type GuildImageType string
@ -266,20 +243,15 @@ const (
GuildBanner4 GuildImageType = "banner4"
)
func (c *Client) GuildImageURL(
guildID discord.Snowflake, img GuildImageType) string {
return EndpointGuilds + guildID.String() +
"/widget.png?style=" + string(img)
func (c *Client) GuildImageURL(guildID discord.Snowflake, img GuildImageType) string {
return EndpointGuilds + guildID.String() + "/widget.png?style=" + string(img)
}
func (c *Client) GuildImage(
guildID discord.Snowflake, img GuildImageType) (io.ReadCloser, error) {
func (c *Client) GuildImage(guildID discord.Snowflake, img GuildImageType) (io.ReadCloser, error) {
r, err := c.Request("GET", c.GuildImageURL(guildID, img))
if err != nil {
return nil, err
}
return r.Body, nil
return r.GetBody(), nil
}

View file

@ -27,8 +27,7 @@ func (c *Client) SendMessageComplex(
if len(data.Files) == 0 {
// No files, so no need for streaming.
return msg, c.RequestJSON(&msg, "POST", URL,
httputil.WithJSONBody(c, data))
return msg, c.RequestJSON(&msg, "POST", URL, httputil.WithJSONBody(c, data))
}
writer := func(mw *multipart.Writer) error {
@ -40,9 +39,10 @@ func (c *Client) SendMessageComplex(
return nil, err
}
defer resp.Body.Close()
var body = resp.GetBody()
defer body.Close()
return msg, c.DecodeStream(resp.Body, &msg)
return msg, c.DecodeStream(body, &msg)
}
const AttachmentSpoilerPrefix = "SPOILER_"
@ -83,9 +83,7 @@ type ExecuteWebhookData struct {
AvatarURL discord.URL `json:"avatar_url,omitempty"`
}
func (data *ExecuteWebhookData) WriteMultipart(
c json.Driver, body *multipart.Writer) error {
func (data *ExecuteWebhookData) WriteMultipart(c json.Driver, body *multipart.Writer) error {
return writeMultipart(c, body, data, data.Files)
}

View file

@ -35,28 +35,21 @@ func (c *Client) CreateWebhook(
}
// Webhooks requires MANAGE_WEBHOOKS.
func (c *Client) Webhooks(
guildID discord.Snowflake) ([]discord.Webhook, error) {
func (c *Client) Webhooks(guildID discord.Snowflake) ([]discord.Webhook, error) {
var ws []discord.Webhook
return ws, c.RequestJSON(&ws, "GET",
EndpointGuilds+guildID.String()+"/webhooks")
return ws, c.RequestJSON(&ws, "GET", EndpointGuilds+guildID.String()+"/webhooks")
}
func (c *Client) Webhook(
webhookID discord.Snowflake) (*discord.Webhook, error) {
func (c *Client) Webhook(webhookID discord.Snowflake) (*discord.Webhook, error) {
var w *discord.Webhook
return w, c.RequestJSON(&w, "GET",
EndpointWebhooks+webhookID.String())
return w, c.RequestJSON(&w, "GET", EndpointWebhooks+webhookID.String())
}
func (c *Client) WebhookWithToken(
webhookID discord.Snowflake, token string) (*discord.Webhook, error) {
var w *discord.Webhook
return w, c.RequestJSON(&w, "GET",
EndpointWebhooks+webhookID.String()+"/"+token)
return w, c.RequestJSON(&w, "GET", EndpointWebhooks+webhookID.String()+"/"+token)
}
type ModifyWebhookData struct {
@ -70,8 +63,7 @@ func (c *Client) ModifyWebhook(
data ModifyWebhookData) (*discord.Webhook, error) {
var w *discord.Webhook
return w, c.RequestJSON(&w, "PATCH",
EndpointWebhooks+webhookID.String())
return w, c.RequestJSON(&w, "PATCH", EndpointWebhooks+webhookID.String())
}
func (c *Client) ModifyWebhookWithToken(
@ -79,26 +71,24 @@ func (c *Client) ModifyWebhookWithToken(
data ModifyWebhookData, token string) (*discord.Webhook, error) {
var w *discord.Webhook
return w, c.RequestJSON(&w, "PATCH",
EndpointWebhooks+webhookID.String()+"/"+token)
return w, c.RequestJSON(&w, "PATCH", EndpointWebhooks+webhookID.String()+"/"+token)
}
func (c *Client) DeleteWebhook(webhookID discord.Snowflake) error {
return c.FastRequest("DELETE", EndpointWebhooks+webhookID.String())
}
func (c *Client) DeleteWebhookWithToken(
webhookID discord.Snowflake, token string) error {
return c.FastRequest("DELETE",
EndpointWebhooks+webhookID.String()+"/"+token)
func (c *Client) DeleteWebhookWithToken(webhookID discord.Snowflake, token string) error {
return c.FastRequest("DELETE", EndpointWebhooks+webhookID.String()+"/"+token)
}
// ExecuteWebhook sends a message to the webhook. If wait is bool, Discord will
// wait for the message to be delivered and will return the message body. This
// also means the returned message will only be there if wait is true.
func (c *Client) ExecuteWebhook(
webhookID discord.Snowflake, token string, wait bool,
webhookID discord.Snowflake,
token string,
wait bool,
data ExecuteWebhookData) (*discord.Message, error) {
for i, embed := range data.Embeds {
@ -112,8 +102,7 @@ func (c *Client) ExecuteWebhook(
param.Set("wait", "true")
}
var URL = EndpointWebhooks + webhookID.String() + "/" + token +
"?" + param.Encode()
var URL = EndpointWebhooks + webhookID.String() + "/" + token + "?" + param.Encode()
var msg *discord.Message
if len(data.Files) == 0 {
@ -131,12 +120,13 @@ func (c *Client) ExecuteWebhook(
return nil, err
}
defer resp.Body.Close()
var body = resp.GetBody()
defer body.Close()
if !wait {
// Since we didn't tell Discord to wait, we have nothing to parse.
return nil, nil
}
return msg, c.DecodeStream(resp.Body, &msg)
return msg, c.DecodeStream(body, &msg)
}

View file

@ -70,22 +70,22 @@ type SessionStartLimit struct {
ResetAfter discord.Milliseconds `json:"reset_after"`
}
// GatewayURL asks Discord for a Websocket URL to the Gateway.
func GatewayURL() (string, error) {
// URL asks Discord for a Websocket URL to the Gateway.
func URL() (string, error) {
var g GatewayBotData
return g.URL, httputil.DefaultClient.RequestJSON(
return g.URL, httputil.NewClient().RequestJSON(
&g, "GET",
EndpointGateway,
)
}
// GatewayBot fetches the Gateway URL along with extra metadata. The token
// BotURL fetches the Gateway URL along with extra metadata. The token
// passed in will NOT be prefixed with Bot.
func GatewayBot(token string) (*GatewayBotData, error) {
func BotURL(token string) (*GatewayBotData, error) {
var g *GatewayBotData
return g, httputil.DefaultClient.RequestJSON(
return g, httputil.NewClient().RequestJSON(
&g, "GET",
EndpointGatewayBot,
httputil.WithHeaders(http.Header{
@ -144,7 +144,7 @@ func NewGateway(token string) (*Gateway, error) {
// NewGatewayWithDriver connects to the Gateway and authenticates automatically.
func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) {
URL, err := GatewayURL()
URL, err := URL()
if err != nil {
return nil, errors.Wrap(err, "Failed to get gateway endpoint")
}

View file

@ -46,7 +46,7 @@ func New(token string) (*Session, error) {
s.Handler = handler.New()
s.Client = api.NewClient(token)
// Open a gateway
// Create a gateway
g, err := gateway.NewGateway(token)
if err != nil {
return nil, errors.Wrap(err, "Failed to connect to Gateway")

View file

@ -3,44 +3,69 @@
package httputil
import (
"bytes"
"context"
"io"
"io/ioutil"
"mime/multipart"
"net/http"
"time"
"github.com/diamondburned/arikawa/utils/httputil/httpdriver"
"github.com/diamondburned/arikawa/utils/json"
"github.com/pkg/errors"
)
// Retries is the default attempts to retry if the API returns an error before
// giving up.
// giving up. If the value is smaller than 1, then requests will retry forever.
var Retries uint = 5
type Client struct {
http.Client
httpdriver.Client
json.Driver
SchemaEncoder
// DefaultOptions, if not nil, will be copied and prefixed on each Request.
DefaultOptions []RequestOption
// OnResponse is called after every Do() call. Response might be nil if Do()
// errors out. The error returned will override Do's if it's not nil.
OnResponse func(httpdriver.Request, httpdriver.Response) error
// Default to the global Retries variable (5).
Retries uint
}
var DefaultClient = NewClient()
// ResponseNoop is used for (*Client).OnResponse.
func ResponseNoop(httpdriver.Request, httpdriver.Response) error {
return nil
}
func NewClient() Client {
return Client{
Client: http.Client{
Timeout: 10 * time.Second,
},
func NewClient() *Client {
return &Client{
Client: httpdriver.NewClient(),
Driver: json.Default{},
SchemaEncoder: &DefaultSchema{},
Retries: Retries,
OnResponse: ResponseNoop,
}
}
func (c *Client) applyOptions(r httpdriver.Request, extra []RequestOption) error {
for _, opt := range c.DefaultOptions {
if err := opt(r); err != nil {
return err
}
}
for _, opt := range extra {
if err := opt(r); err != nil {
return err
}
}
return nil
}
func (c *Client) MeanwhileMultipart(
multipartWriter func(*multipart.Writer) error,
method, url string, opts ...RequestOption) (*http.Response, error) {
writer func(*multipart.Writer) error,
method, url string, opts ...RequestOption) (httpdriver.Response, error) {
// We want to cancel the request if our bodyWriter fails
ctx, cancel := context.WithCancel(context.Background())
@ -52,7 +77,7 @@ func (c *Client) MeanwhileMultipart(
var bgErr error
go func() {
if err := multipartWriter(body); err != nil {
if err := writer(body); err != nil {
bgErr = err
cancel()
}
@ -61,119 +86,122 @@ func (c *Client) MeanwhileMultipart(
w.Close()
}()
resp, err := c.RequestCtx(ctx, method, url,
append([]RequestOption{
WithBody(r),
WithContentType(body.FormDataContentType()),
}, opts...)...)
// Prepend the multipart writer and the correct Content-Type header options.
opts = PrependOptions(
opts,
WithBody(r),
WithContentType(body.FormDataContentType()),
)
resp, err := c.RequestCtx(ctx, method, url, opts...)
if err != nil && bgErr != nil {
if resp.Body != nil {
resp.Body.Close()
}
return nil, bgErr
}
return resp, err
}
func (c *Client) FastRequest(
method, url string, opts ...RequestOption) error {
func (c *Client) FastRequest(method, url string, opts ...RequestOption) error {
r, err := c.Request(method, url, opts...)
if err != nil {
return err
}
return r.Body.Close()
return r.GetBody().Close()
}
func (c *Client) RequestCtx(ctx context.Context,
method, url string, opts ...RequestOption) (*http.Response, error) {
func (c *Client) RequestCtxJSON(
ctx context.Context,
to interface{}, method, url string, opts ...RequestOption) error {
req, err := http.NewRequestWithContext(ctx, method, url, nil)
opts = PrependOptions(opts, JSONRequest)
r, err := c.RequestCtx(ctx, method, url, opts...)
if err != nil {
return err
}
var body, status = r.GetBody(), r.GetStatus()
defer body.Close()
// No content, working as intended (tm)
if status == httpdriver.NoContent {
return nil
}
if err := c.DecodeStream(body, to); err != nil {
return JSONError{err}
}
return nil
}
func (c *Client) RequestCtx(
ctx context.Context,
method, url string, opts ...RequestOption) (httpdriver.Response, error) {
req, err := c.Client.NewRequest(ctx, method, url)
if err != nil {
return nil, RequestError{err}
}
for _, opt := range opts {
if err := opt(req); err != nil {
return nil, err
}
if err := c.applyOptions(req, opts); err != nil {
return nil, errors.Wrap(err, "Failed to apply options")
}
var r *http.Response
var r httpdriver.Response
var status int
for i := uint(0); i < c.Retries; i++ {
for i := uint(0); c.Retries < 1 || i < c.Retries; i++ {
r, err = c.Client.Do(req)
if err != nil {
continue
}
if r.StatusCode < 200 || r.StatusCode > 299 {
if status = r.GetStatus(); status < 200 || status > 299 {
continue
}
break
}
// Call OnResponse() even if the request failed.
if err := c.OnResponse(req, r); err != nil {
return nil, err
}
// If all retries failed:
if err != nil {
return nil, RequestError{err}
}
// Response received, but with a failure status code:
if r.StatusCode < 200 || r.StatusCode > 299 {
if status < 200 || status > 299 {
// Try and parse the body.
var body = r.GetBody()
defer body.Close()
// This rarely happens, so we can (probably) make an exception for it.
buf := bytes.Buffer{}
buf.ReadFrom(body)
httpErr := &HTTPError{
Status: r.StatusCode,
Status: status,
Body: buf.Bytes(),
}
b, err := ioutil.ReadAll(r.Body)
if err != nil {
return nil, httpErr
}
// Optionally unmarshal the error.
c.Unmarshal(httpErr.Body, &httpErr)
httpErr.Body = b
c.Unmarshal(b, &httpErr)
return nil, httpErr
}
return r, nil
}
func (c *Client) RequestCtxJSON(ctx context.Context,
to interface{}, method, url string, opts ...RequestOption) error {
r, err := c.RequestCtx(ctx, method, url,
append([]RequestOption{JSONRequest}, opts...)...)
if err != nil {
return err
}
defer r.Body.Close()
// No content, working as intended (tm)
if r.StatusCode == http.StatusNoContent {
return nil
}
if err := c.DecodeStream(r.Body, to); err != nil {
return JSONError{err}
}
return nil
}
func (c *Client) Request(
method, url string, opts ...RequestOption) (*http.Response, error) {
func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.Response, error) {
return c.RequestCtx(context.Background(), method, url, opts...)
}
func (c *Client) RequestJSON(
to interface{}, method, url string, opts ...RequestOption) error {
func (c *Client) RequestJSON(to interface{}, method, url string, opts ...RequestOption) error {
return c.RequestCtxJSON(context.Background(), to, method, url, opts...)
}

View file

@ -1,36 +0,0 @@
package httputil
import (
"net/http"
)
type TransportWrapper struct {
Default http.RoundTripper
Pre func(*http.Request) error
Post func(*http.Request, *http.Response) error
}
var _ http.RoundTripper = (*TransportWrapper)(nil)
func NewTransportWrapper() *TransportWrapper {
return &TransportWrapper{
Default: http.DefaultTransport,
Pre: func(*http.Request) error { return nil },
Post: func(*http.Request, *http.Response) error { return nil },
}
}
func (c *TransportWrapper) RoundTrip(req *http.Request) (r *http.Response, err error) {
if err := c.Pre(req); err != nil {
return nil, err
}
r, err = c.Default.RoundTrip(req)
// Call Post regardless of error:
if postErr := c.Post(req, r); postErr != nil {
return r, postErr
}
return r, err
}

View file

@ -0,0 +1,103 @@
package httpdriver
import (
"context"
"io"
"net/http"
"net/url"
"time"
)
// DefaultClient implements Client and wraps around the stdlib Client.
type DefaultClient struct {
Client http.Client
}
var _ Client = (*DefaultClient)(nil)
// WrapClient wraps around the standard library's http.Client and returns an
// implementation that's compatible with the Client driver interface.
func WrapClient(client http.Client) Client {
return &DefaultClient{
Client: client,
}
}
// NewClient creates a new client around the standard library's http.Client. The
// client will have a timeout of 10 seconds.
func NewClient() Client {
return WrapClient(http.Client{
Timeout: 10 * time.Second,
})
}
func (d *DefaultClient) NewRequest(ctx context.Context, method, url string) (Request, error) {
r, err := http.NewRequestWithContext(ctx, method, url, nil)
if err != nil {
return nil, err
}
return (*DefaultRequest)(r), nil
}
func (d *DefaultClient) Do(req Request) (Response, error) {
// Implementations can safely assert this.
request := req.(*DefaultRequest)
r, err := d.Client.Do((*http.Request)(request))
if err != nil {
return nil, err
}
return (*DefaultResponse)(r), nil
}
// DefaultRequest wraps around the stdlib Request and satisfies the Request
// interface.
type DefaultRequest http.Request
var _ Request = (*DefaultRequest)(nil)
func (r *DefaultRequest) GetPath() string {
return r.URL.Path
}
func (r *DefaultRequest) GetContext() context.Context {
return (*http.Request)(r).Context()
}
func (r *DefaultRequest) AddQuery(values url.Values) {
var qs = r.URL.Query()
for k, v := range values {
qs[k] = append(qs[k], v...)
}
r.URL.RawQuery = qs.Encode()
}
func (r *DefaultRequest) AddHeader(header http.Header) {
for key, values := range header {
r.Header[key] = append(r.Header[key], values...)
}
}
func (r *DefaultRequest) WithBody(body io.ReadCloser) {
r.Body = body
}
// DefaultResponse wraps around the stdlib Response and satisfies the Response
// interface.
type DefaultResponse http.Response
var _ Response = (*DefaultResponse)(nil)
func (r *DefaultResponse) GetStatus() int {
return r.StatusCode
}
func (r *DefaultResponse) GetHeader() http.Header {
return r.Header
}
func (r *DefaultResponse) GetBody() io.ReadCloser {
return r.Body
}

View file

@ -0,0 +1,55 @@
// Package httpdriver provides interfaces and implementations of a simple HTTP
// client.
package httpdriver
import (
"context"
"io"
"net/http"
"net/url"
)
// NoContent is the status code for HTTP 204, or http.StatusNoContent.
const NoContent = 204
// Client is a generic interface used as an adapter allowing for custom HTTP
// client implementations, such as fasthttp.
type Client interface {
NewRequest(ctx context.Context, method, url string) (Request, error)
Do(req Request) (Response, error)
}
// Request is a generic interface for a normal HTTP request. It should be
// constructed using (Requester).NewRequest().
type Request interface {
// GetPath should return the URL path, for example "/endpoint/thing".
GetPath() string
// GetContext should return the same context that's passed into NewRequest.
// For implementations that don't support this, it can remove a
// context.Background().
GetContext() context.Context
// AddHeader appends headers.
AddHeader(http.Header)
// AddQuery appends URL query values.
AddQuery(url.Values)
// WithBody should automatically close the ReadCloser on finish. This is
// similar to the stdlib's Request behavior.
WithBody(io.ReadCloser)
}
// Response is returned from (Requester).DoContext().
type Response interface {
GetStatus() int
GetHeader() http.Header
// Body's ReadCloser will always be closed when done, unless DoContext()
// returns an error.
GetBody() io.ReadCloser
}
// OptHeader returns the response header, or nil if from is nil.
func OptHeader(from Response) http.Header {
if from == nil {
return nil
}
return from.GetHeader()
}

View file

@ -4,86 +4,93 @@ import (
"io"
"net/http"
"github.com/diamondburned/arikawa/utils/httputil/httpdriver"
"github.com/diamondburned/arikawa/utils/json"
)
type RequestOption func(*http.Request) error
type RequestOption func(httpdriver.Request) error
func JSONRequest(r *http.Request) error {
r.Header.Set("Content-Type", "application/json")
func PrependOptions(opts []RequestOption, prepend ...RequestOption) []RequestOption {
if len(opts) == 0 {
return prepend
}
return append(prepend, opts...)
}
func JSONRequest(r httpdriver.Request) error {
r.AddHeader(http.Header{
"Content-Type": {"application/json"},
})
return nil
}
func MultipartRequest(r *http.Request) error {
r.Header.Set("Content-Type", "multipart/form-data")
func MultipartRequest(r httpdriver.Request) error {
r.AddHeader(http.Header{
"Content-Type": {"multipart/form-data"},
})
return nil
}
func WithHeaders(headers http.Header) RequestOption {
return func(r *http.Request) error {
for key, values := range headers {
r.Header[key] = append(r.Header[key], values...)
}
return func(r httpdriver.Request) error {
r.AddHeader(headers)
return nil
}
}
func WithContentType(ctype string) RequestOption {
return func(r *http.Request) error {
r.Header.Set("Content-Type", ctype)
return func(r httpdriver.Request) error {
r.AddHeader(http.Header{
"Content-Type": {ctype},
})
return nil
}
}
func WithSchema(schema SchemaEncoder, v interface{}) RequestOption {
return func(r *http.Request) error {
return func(r httpdriver.Request) error {
params, err := schema.Encode(v)
if err != nil {
return err
}
var qs = r.URL.Query()
for k, v := range params {
qs[k] = append(qs[k], v...)
}
r.URL.RawQuery = qs.Encode()
r.AddQuery(params)
return nil
}
}
func WithBody(body io.ReadCloser) RequestOption {
return func(r *http.Request) error {
// tee := io.TeeReader(body, os.Stderr)
// r.Body = ioutil.NopCloser(tee)
r.Body = body
r.ContentLength = -1
return func(r httpdriver.Request) error {
r.WithBody(body)
return nil
}
}
// WithJSONBody inserts a JSON body into the request. This ignores JSON errors.
func WithJSONBody(json json.Driver, v interface{}) RequestOption {
if v == nil {
return func(*http.Request) error {
return func(httpdriver.Request) error {
return nil
}
}
var err error
var rp, wp = io.Pipe()
go func() {
err = json.EncodeStream(wp, v)
json.EncodeStream(wp, v)
wp.Close()
}()
return func(r *http.Request) error {
if err != nil {
return err
}
return func(r httpdriver.Request) error {
// TODO: maybe do something to this?
// if err != nil {
// return err
// }
r.Header.Set("Content-Type", "application/json")
r.Body = rp
r.AddHeader(http.Header{
"Content-Type": {"application/json"},
})
r.WithBody(rp)
return nil
}
}