mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-01-20 11:37:56 +00:00
Utils: Added HTTP drivers
This commit is contained in:
parent
bf93a9cee9
commit
2afe683b7d
46
api/api.go
46
api/api.go
|
@ -7,6 +7,7 @@ import (
|
|||
|
||||
"github.com/diamondburned/arikawa/api/rate"
|
||||
"github.com/diamondburned/arikawa/utils/httputil"
|
||||
"github.com/diamondburned/arikawa/utils/httputil/httpdriver"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -22,39 +23,40 @@ var (
|
|||
var UserAgent = "DiscordBot (https://github.com/diamondburned/arikawa, v0.0.1)"
|
||||
|
||||
type Client struct {
|
||||
httputil.Client
|
||||
*httputil.Client
|
||||
Limiter *rate.Limiter
|
||||
|
||||
Token string
|
||||
Token string
|
||||
UserAgent string
|
||||
}
|
||||
|
||||
func NewClient(token string) *Client {
|
||||
return NewCustomClient(token, httputil.NewClient())
|
||||
}
|
||||
|
||||
func NewCustomClient(token string, httpClient *httputil.Client) *Client {
|
||||
cli := &Client{
|
||||
Client: httputil.DefaultClient,
|
||||
Limiter: rate.NewLimiter(APIPath),
|
||||
Token: token,
|
||||
Client: httpClient,
|
||||
Limiter: rate.NewLimiter(APIPath),
|
||||
Token: token,
|
||||
UserAgent: UserAgent,
|
||||
}
|
||||
|
||||
tw := httputil.NewTransportWrapper()
|
||||
tw.Pre = func(r *http.Request) error {
|
||||
if cli.Token != "" {
|
||||
r.Header.Set("Authorization", cli.Token)
|
||||
}
|
||||
cli.DefaultOptions = []httputil.RequestOption{
|
||||
func(r httpdriver.Request) error {
|
||||
r.AddHeader(http.Header{
|
||||
"Authorization": {cli.Token},
|
||||
"User-Agent": {cli.UserAgent},
|
||||
"X-RateLimit-Precision": {"millisecond"},
|
||||
})
|
||||
|
||||
r.Header.Set("User-Agent", UserAgent)
|
||||
r.Header.Set("X-RateLimit-Precision", "millisecond")
|
||||
|
||||
// Rate limit stuff
|
||||
return cli.Limiter.Acquire(r.Context(), r.URL.Path)
|
||||
// Rate limit stuff
|
||||
return cli.Limiter.Acquire(r.GetContext(), r.GetPath())
|
||||
},
|
||||
}
|
||||
tw.Post = func(r *http.Request, resp *http.Response) error {
|
||||
if resp == nil {
|
||||
return cli.Limiter.Release(r.URL.Path, nil)
|
||||
}
|
||||
return cli.Limiter.Release(r.URL.Path, resp.Header)
|
||||
cli.OnResponse = func(r httpdriver.Request, resp httpdriver.Response) error {
|
||||
return cli.Limiter.Release(r.GetPath(), httpdriver.OptHeader(resp))
|
||||
}
|
||||
|
||||
cli.Client.Transport = tw
|
||||
|
||||
return cli
|
||||
}
|
||||
|
|
80
api/guild.go
80
api/guild.go
|
@ -32,8 +32,7 @@ type CreateGuildData struct {
|
|||
|
||||
func (c *Client) CreateGuild(data CreateGuildData) (*discord.Guild, error) {
|
||||
var g *discord.Guild
|
||||
return g, c.RequestJSON(&g, "POST", Endpoint+"guilds",
|
||||
httputil.WithJSONBody(c, data))
|
||||
return g, c.RequestJSON(&g, "POST", Endpoint+"guilds", httputil.WithJSONBody(c, data))
|
||||
}
|
||||
|
||||
func (c *Client) Guild(guildID discord.Snowflake) (*discord.Guild, error) {
|
||||
|
@ -76,23 +75,17 @@ func (c *Client) Guilds(max uint) ([]discord.Guild, error) {
|
|||
}
|
||||
|
||||
// GuildsBefore fetches guilds. Check GuildsRange.
|
||||
func (c *Client) GuildsBefore(
|
||||
before discord.Snowflake, limit uint) ([]discord.Guild, error) {
|
||||
|
||||
func (c *Client) GuildsBefore(before discord.Snowflake, limit uint) ([]discord.Guild, error) {
|
||||
return c.GuildsRange(before, 0, limit)
|
||||
}
|
||||
|
||||
// GuildsAfter fetches guilds. Check GuildsRange.
|
||||
func (c *Client) GuildsAfter(
|
||||
after discord.Snowflake, limit uint) ([]discord.Guild, error) {
|
||||
|
||||
func (c *Client) GuildsAfter(after discord.Snowflake, limit uint) ([]discord.Guild, error) {
|
||||
return c.GuildsRange(0, after, limit)
|
||||
}
|
||||
|
||||
// GuildsRange fetches guilds. The limit is 1-100.
|
||||
func (c *Client) GuildsRange(
|
||||
before, after discord.Snowflake, limit uint) ([]discord.Guild, error) {
|
||||
|
||||
func (c *Client) GuildsRange(before, after discord.Snowflake, limit uint) ([]discord.Guild, error) {
|
||||
if limit == 0 {
|
||||
limit = 100
|
||||
}
|
||||
|
@ -163,21 +156,15 @@ func (c *Client) DeleteGuild(guildID discord.Snowflake) error {
|
|||
|
||||
// GuildVoiceRegions is the same as /voice, but returns VIP ones as well if
|
||||
// available.
|
||||
func (c *Client) VoiceRegionsGuild(
|
||||
guildID discord.Snowflake) ([]discord.VoiceRegion, error) {
|
||||
|
||||
func (c *Client) VoiceRegionsGuild(guildID discord.Snowflake) ([]discord.VoiceRegion, error) {
|
||||
var vrs []discord.VoiceRegion
|
||||
return vrs, c.RequestJSON(&vrs, "GET",
|
||||
EndpointGuilds+guildID.String()+"/regions")
|
||||
return vrs, c.RequestJSON(&vrs, "GET", EndpointGuilds+guildID.String()+"/regions")
|
||||
}
|
||||
|
||||
// Integrations requires MANAGE_GUILD.
|
||||
func (c *Client) Integrations(
|
||||
guildID discord.Snowflake) ([]discord.Integration, error) {
|
||||
|
||||
func (c *Client) Integrations(guildID discord.Snowflake) ([]discord.Integration, error) {
|
||||
var ints []discord.Integration
|
||||
return ints, c.RequestJSON(&ints, "GET",
|
||||
EndpointGuilds+guildID.String()+"/integrations")
|
||||
return ints, c.RequestJSON(&ints, "GET", EndpointGuilds+guildID.String()+"/integrations")
|
||||
}
|
||||
|
||||
// AttachIntegration requires MANAGE_GUILD.
|
||||
|
@ -214,46 +201,36 @@ func (c *Client) ModifyIntegration(
|
|||
|
||||
return c.FastRequest(
|
||||
"PATCH",
|
||||
EndpointGuilds+guildID.String()+
|
||||
"/integrations/"+integrationID.String(),
|
||||
EndpointGuilds+guildID.String()+"/integrations/"+integrationID.String(),
|
||||
httputil.WithSchema(c, param),
|
||||
)
|
||||
}
|
||||
|
||||
func (c *Client) SyncIntegration(
|
||||
guildID, integrationID discord.Snowflake) error {
|
||||
|
||||
func (c *Client) SyncIntegration(guildID, integrationID discord.Snowflake) error {
|
||||
return c.FastRequest("POST", EndpointGuilds+guildID.String()+
|
||||
"/integrations/"+integrationID.String()+"/sync")
|
||||
}
|
||||
|
||||
func (c *Client) GuildEmbed(
|
||||
guildID discord.Snowflake) (*discord.GuildEmbed, error) {
|
||||
|
||||
func (c *Client) GuildEmbed(guildID discord.Snowflake) (*discord.GuildEmbed, error) {
|
||||
var ge *discord.GuildEmbed
|
||||
return ge, c.RequestJSON(&ge, "GET",
|
||||
EndpointGuilds+guildID.String()+"/embed")
|
||||
return ge, c.RequestJSON(&ge, "GET", EndpointGuilds+guildID.String()+"/embed")
|
||||
}
|
||||
|
||||
// ModifyGuildEmbed should be used with care: if you still want the embed
|
||||
// enabled, you need to set the Enabled boolean, even if it's already enabled.
|
||||
// If you don't, JSON will default it to false.
|
||||
func (c *Client) ModifyGuildEmbed(
|
||||
guildID discord.Snowflake,
|
||||
data discord.GuildEmbed) (*discord.GuildEmbed, error) {
|
||||
|
||||
return &data, c.RequestJSON(&data, "PATCH",
|
||||
EndpointGuilds+guildID.String()+"/embed")
|
||||
// ModifyGuildEmbed modifies the guild embed and updates the passed in
|
||||
// GuildEmbed data.
|
||||
//
|
||||
// This method should be used with care: if you still want the embed enabled,
|
||||
// you need to set the Enabled boolean, even if it's already enabled. If you
|
||||
// don't, JSON will default it to false.
|
||||
func (c *Client) ModifyGuildEmbed(guildID discord.Snowflake, data *discord.GuildEmbed) error {
|
||||
return c.RequestJSON(&data, "PATCH", EndpointGuilds+guildID.String()+"/embed")
|
||||
}
|
||||
|
||||
// GuildVanityURL returns *Invite, but only Code and Uses are filled. Requires
|
||||
// MANAGE_GUILD.
|
||||
func (c *Client) GuildVanityURL(
|
||||
guildID discord.Snowflake) (*discord.Invite, error) {
|
||||
|
||||
func (c *Client) GuildVanityURL(guildID discord.Snowflake) (*discord.Invite, error) {
|
||||
var inv *discord.Invite
|
||||
return inv, c.RequestJSON(&inv, "GET",
|
||||
EndpointGuilds+guildID.String()+"/vanity-url")
|
||||
return inv, c.RequestJSON(&inv, "GET", EndpointGuilds+guildID.String()+"/vanity-url")
|
||||
}
|
||||
|
||||
type GuildImageType string
|
||||
|
@ -266,20 +243,15 @@ const (
|
|||
GuildBanner4 GuildImageType = "banner4"
|
||||
)
|
||||
|
||||
func (c *Client) GuildImageURL(
|
||||
guildID discord.Snowflake, img GuildImageType) string {
|
||||
|
||||
return EndpointGuilds + guildID.String() +
|
||||
"/widget.png?style=" + string(img)
|
||||
func (c *Client) GuildImageURL(guildID discord.Snowflake, img GuildImageType) string {
|
||||
return EndpointGuilds + guildID.String() + "/widget.png?style=" + string(img)
|
||||
}
|
||||
|
||||
func (c *Client) GuildImage(
|
||||
guildID discord.Snowflake, img GuildImageType) (io.ReadCloser, error) {
|
||||
|
||||
func (c *Client) GuildImage(guildID discord.Snowflake, img GuildImageType) (io.ReadCloser, error) {
|
||||
r, err := c.Request("GET", c.GuildImageURL(guildID, img))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return r.Body, nil
|
||||
return r.GetBody(), nil
|
||||
}
|
||||
|
|
|
@ -27,8 +27,7 @@ func (c *Client) SendMessageComplex(
|
|||
|
||||
if len(data.Files) == 0 {
|
||||
// No files, so no need for streaming.
|
||||
return msg, c.RequestJSON(&msg, "POST", URL,
|
||||
httputil.WithJSONBody(c, data))
|
||||
return msg, c.RequestJSON(&msg, "POST", URL, httputil.WithJSONBody(c, data))
|
||||
}
|
||||
|
||||
writer := func(mw *multipart.Writer) error {
|
||||
|
@ -40,9 +39,10 @@ func (c *Client) SendMessageComplex(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
var body = resp.GetBody()
|
||||
defer body.Close()
|
||||
|
||||
return msg, c.DecodeStream(resp.Body, &msg)
|
||||
return msg, c.DecodeStream(body, &msg)
|
||||
}
|
||||
|
||||
const AttachmentSpoilerPrefix = "SPOILER_"
|
||||
|
@ -83,9 +83,7 @@ type ExecuteWebhookData struct {
|
|||
AvatarURL discord.URL `json:"avatar_url,omitempty"`
|
||||
}
|
||||
|
||||
func (data *ExecuteWebhookData) WriteMultipart(
|
||||
c json.Driver, body *multipart.Writer) error {
|
||||
|
||||
func (data *ExecuteWebhookData) WriteMultipart(c json.Driver, body *multipart.Writer) error {
|
||||
return writeMultipart(c, body, data, data.Files)
|
||||
}
|
||||
|
||||
|
|
|
@ -35,28 +35,21 @@ func (c *Client) CreateWebhook(
|
|||
}
|
||||
|
||||
// Webhooks requires MANAGE_WEBHOOKS.
|
||||
func (c *Client) Webhooks(
|
||||
guildID discord.Snowflake) ([]discord.Webhook, error) {
|
||||
|
||||
func (c *Client) Webhooks(guildID discord.Snowflake) ([]discord.Webhook, error) {
|
||||
var ws []discord.Webhook
|
||||
return ws, c.RequestJSON(&ws, "GET",
|
||||
EndpointGuilds+guildID.String()+"/webhooks")
|
||||
return ws, c.RequestJSON(&ws, "GET", EndpointGuilds+guildID.String()+"/webhooks")
|
||||
}
|
||||
|
||||
func (c *Client) Webhook(
|
||||
webhookID discord.Snowflake) (*discord.Webhook, error) {
|
||||
|
||||
func (c *Client) Webhook(webhookID discord.Snowflake) (*discord.Webhook, error) {
|
||||
var w *discord.Webhook
|
||||
return w, c.RequestJSON(&w, "GET",
|
||||
EndpointWebhooks+webhookID.String())
|
||||
return w, c.RequestJSON(&w, "GET", EndpointWebhooks+webhookID.String())
|
||||
}
|
||||
|
||||
func (c *Client) WebhookWithToken(
|
||||
webhookID discord.Snowflake, token string) (*discord.Webhook, error) {
|
||||
|
||||
var w *discord.Webhook
|
||||
return w, c.RequestJSON(&w, "GET",
|
||||
EndpointWebhooks+webhookID.String()+"/"+token)
|
||||
return w, c.RequestJSON(&w, "GET", EndpointWebhooks+webhookID.String()+"/"+token)
|
||||
}
|
||||
|
||||
type ModifyWebhookData struct {
|
||||
|
@ -70,8 +63,7 @@ func (c *Client) ModifyWebhook(
|
|||
data ModifyWebhookData) (*discord.Webhook, error) {
|
||||
|
||||
var w *discord.Webhook
|
||||
return w, c.RequestJSON(&w, "PATCH",
|
||||
EndpointWebhooks+webhookID.String())
|
||||
return w, c.RequestJSON(&w, "PATCH", EndpointWebhooks+webhookID.String())
|
||||
}
|
||||
|
||||
func (c *Client) ModifyWebhookWithToken(
|
||||
|
@ -79,26 +71,24 @@ func (c *Client) ModifyWebhookWithToken(
|
|||
data ModifyWebhookData, token string) (*discord.Webhook, error) {
|
||||
|
||||
var w *discord.Webhook
|
||||
return w, c.RequestJSON(&w, "PATCH",
|
||||
EndpointWebhooks+webhookID.String()+"/"+token)
|
||||
return w, c.RequestJSON(&w, "PATCH", EndpointWebhooks+webhookID.String()+"/"+token)
|
||||
}
|
||||
|
||||
func (c *Client) DeleteWebhook(webhookID discord.Snowflake) error {
|
||||
return c.FastRequest("DELETE", EndpointWebhooks+webhookID.String())
|
||||
}
|
||||
|
||||
func (c *Client) DeleteWebhookWithToken(
|
||||
webhookID discord.Snowflake, token string) error {
|
||||
|
||||
return c.FastRequest("DELETE",
|
||||
EndpointWebhooks+webhookID.String()+"/"+token)
|
||||
func (c *Client) DeleteWebhookWithToken(webhookID discord.Snowflake, token string) error {
|
||||
return c.FastRequest("DELETE", EndpointWebhooks+webhookID.String()+"/"+token)
|
||||
}
|
||||
|
||||
// ExecuteWebhook sends a message to the webhook. If wait is bool, Discord will
|
||||
// wait for the message to be delivered and will return the message body. This
|
||||
// also means the returned message will only be there if wait is true.
|
||||
func (c *Client) ExecuteWebhook(
|
||||
webhookID discord.Snowflake, token string, wait bool,
|
||||
webhookID discord.Snowflake,
|
||||
token string,
|
||||
wait bool,
|
||||
data ExecuteWebhookData) (*discord.Message, error) {
|
||||
|
||||
for i, embed := range data.Embeds {
|
||||
|
@ -112,8 +102,7 @@ func (c *Client) ExecuteWebhook(
|
|||
param.Set("wait", "true")
|
||||
}
|
||||
|
||||
var URL = EndpointWebhooks + webhookID.String() + "/" + token +
|
||||
"?" + param.Encode()
|
||||
var URL = EndpointWebhooks + webhookID.String() + "/" + token + "?" + param.Encode()
|
||||
var msg *discord.Message
|
||||
|
||||
if len(data.Files) == 0 {
|
||||
|
@ -131,12 +120,13 @@ func (c *Client) ExecuteWebhook(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
var body = resp.GetBody()
|
||||
defer body.Close()
|
||||
|
||||
if !wait {
|
||||
// Since we didn't tell Discord to wait, we have nothing to parse.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return msg, c.DecodeStream(resp.Body, &msg)
|
||||
return msg, c.DecodeStream(body, &msg)
|
||||
}
|
||||
|
|
|
@ -70,22 +70,22 @@ type SessionStartLimit struct {
|
|||
ResetAfter discord.Milliseconds `json:"reset_after"`
|
||||
}
|
||||
|
||||
// GatewayURL asks Discord for a Websocket URL to the Gateway.
|
||||
func GatewayURL() (string, error) {
|
||||
// URL asks Discord for a Websocket URL to the Gateway.
|
||||
func URL() (string, error) {
|
||||
var g GatewayBotData
|
||||
|
||||
return g.URL, httputil.DefaultClient.RequestJSON(
|
||||
return g.URL, httputil.NewClient().RequestJSON(
|
||||
&g, "GET",
|
||||
EndpointGateway,
|
||||
)
|
||||
}
|
||||
|
||||
// GatewayBot fetches the Gateway URL along with extra metadata. The token
|
||||
// BotURL fetches the Gateway URL along with extra metadata. The token
|
||||
// passed in will NOT be prefixed with Bot.
|
||||
func GatewayBot(token string) (*GatewayBotData, error) {
|
||||
func BotURL(token string) (*GatewayBotData, error) {
|
||||
var g *GatewayBotData
|
||||
|
||||
return g, httputil.DefaultClient.RequestJSON(
|
||||
return g, httputil.NewClient().RequestJSON(
|
||||
&g, "GET",
|
||||
EndpointGatewayBot,
|
||||
httputil.WithHeaders(http.Header{
|
||||
|
@ -144,7 +144,7 @@ func NewGateway(token string) (*Gateway, error) {
|
|||
|
||||
// NewGatewayWithDriver connects to the Gateway and authenticates automatically.
|
||||
func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) {
|
||||
URL, err := GatewayURL()
|
||||
URL, err := URL()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to get gateway endpoint")
|
||||
}
|
||||
|
|
|
@ -46,7 +46,7 @@ func New(token string) (*Session, error) {
|
|||
s.Handler = handler.New()
|
||||
s.Client = api.NewClient(token)
|
||||
|
||||
// Open a gateway
|
||||
// Create a gateway
|
||||
g, err := gateway.NewGateway(token)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to connect to Gateway")
|
||||
|
|
|
@ -3,44 +3,69 @@
|
|||
package httputil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/diamondburned/arikawa/utils/httputil/httpdriver"
|
||||
"github.com/diamondburned/arikawa/utils/json"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Retries is the default attempts to retry if the API returns an error before
|
||||
// giving up.
|
||||
// giving up. If the value is smaller than 1, then requests will retry forever.
|
||||
var Retries uint = 5
|
||||
|
||||
type Client struct {
|
||||
http.Client
|
||||
httpdriver.Client
|
||||
json.Driver
|
||||
SchemaEncoder
|
||||
|
||||
// DefaultOptions, if not nil, will be copied and prefixed on each Request.
|
||||
DefaultOptions []RequestOption
|
||||
|
||||
// OnResponse is called after every Do() call. Response might be nil if Do()
|
||||
// errors out. The error returned will override Do's if it's not nil.
|
||||
OnResponse func(httpdriver.Request, httpdriver.Response) error
|
||||
|
||||
// Default to the global Retries variable (5).
|
||||
Retries uint
|
||||
}
|
||||
|
||||
var DefaultClient = NewClient()
|
||||
// ResponseNoop is used for (*Client).OnResponse.
|
||||
func ResponseNoop(httpdriver.Request, httpdriver.Response) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewClient() Client {
|
||||
return Client{
|
||||
Client: http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
func NewClient() *Client {
|
||||
return &Client{
|
||||
Client: httpdriver.NewClient(),
|
||||
Driver: json.Default{},
|
||||
SchemaEncoder: &DefaultSchema{},
|
||||
Retries: Retries,
|
||||
OnResponse: ResponseNoop,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) applyOptions(r httpdriver.Request, extra []RequestOption) error {
|
||||
for _, opt := range c.DefaultOptions {
|
||||
if err := opt(r); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for _, opt := range extra {
|
||||
if err := opt(r); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) MeanwhileMultipart(
|
||||
multipartWriter func(*multipart.Writer) error,
|
||||
method, url string, opts ...RequestOption) (*http.Response, error) {
|
||||
writer func(*multipart.Writer) error,
|
||||
method, url string, opts ...RequestOption) (httpdriver.Response, error) {
|
||||
|
||||
// We want to cancel the request if our bodyWriter fails
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
@ -52,7 +77,7 @@ func (c *Client) MeanwhileMultipart(
|
|||
var bgErr error
|
||||
|
||||
go func() {
|
||||
if err := multipartWriter(body); err != nil {
|
||||
if err := writer(body); err != nil {
|
||||
bgErr = err
|
||||
cancel()
|
||||
}
|
||||
|
@ -61,119 +86,122 @@ func (c *Client) MeanwhileMultipart(
|
|||
w.Close()
|
||||
}()
|
||||
|
||||
resp, err := c.RequestCtx(ctx, method, url,
|
||||
append([]RequestOption{
|
||||
WithBody(r),
|
||||
WithContentType(body.FormDataContentType()),
|
||||
}, opts...)...)
|
||||
// Prepend the multipart writer and the correct Content-Type header options.
|
||||
opts = PrependOptions(
|
||||
opts,
|
||||
WithBody(r),
|
||||
WithContentType(body.FormDataContentType()),
|
||||
)
|
||||
|
||||
resp, err := c.RequestCtx(ctx, method, url, opts...)
|
||||
if err != nil && bgErr != nil {
|
||||
if resp.Body != nil {
|
||||
resp.Body.Close()
|
||||
}
|
||||
|
||||
return nil, bgErr
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (c *Client) FastRequest(
|
||||
method, url string, opts ...RequestOption) error {
|
||||
|
||||
func (c *Client) FastRequest(method, url string, opts ...RequestOption) error {
|
||||
r, err := c.Request(method, url, opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return r.Body.Close()
|
||||
return r.GetBody().Close()
|
||||
}
|
||||
|
||||
func (c *Client) RequestCtx(ctx context.Context,
|
||||
method, url string, opts ...RequestOption) (*http.Response, error) {
|
||||
func (c *Client) RequestCtxJSON(
|
||||
ctx context.Context,
|
||||
to interface{}, method, url string, opts ...RequestOption) error {
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, nil)
|
||||
opts = PrependOptions(opts, JSONRequest)
|
||||
|
||||
r, err := c.RequestCtx(ctx, method, url, opts...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var body, status = r.GetBody(), r.GetStatus()
|
||||
defer body.Close()
|
||||
|
||||
// No content, working as intended (tm)
|
||||
if status == httpdriver.NoContent {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := c.DecodeStream(body, to); err != nil {
|
||||
return JSONError{err}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) RequestCtx(
|
||||
ctx context.Context,
|
||||
method, url string, opts ...RequestOption) (httpdriver.Response, error) {
|
||||
|
||||
req, err := c.Client.NewRequest(ctx, method, url)
|
||||
if err != nil {
|
||||
return nil, RequestError{err}
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
if err := opt(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := c.applyOptions(req, opts); err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to apply options")
|
||||
}
|
||||
|
||||
var r *http.Response
|
||||
var r httpdriver.Response
|
||||
var status int
|
||||
|
||||
for i := uint(0); i < c.Retries; i++ {
|
||||
for i := uint(0); c.Retries < 1 || i < c.Retries; i++ {
|
||||
r, err = c.Client.Do(req)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if r.StatusCode < 200 || r.StatusCode > 299 {
|
||||
if status = r.GetStatus(); status < 200 || status > 299 {
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
// Call OnResponse() even if the request failed.
|
||||
if err := c.OnResponse(req, r); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If all retries failed:
|
||||
if err != nil {
|
||||
return nil, RequestError{err}
|
||||
}
|
||||
|
||||
// Response received, but with a failure status code:
|
||||
if r.StatusCode < 200 || r.StatusCode > 299 {
|
||||
if status < 200 || status > 299 {
|
||||
// Try and parse the body.
|
||||
var body = r.GetBody()
|
||||
defer body.Close()
|
||||
|
||||
// This rarely happens, so we can (probably) make an exception for it.
|
||||
buf := bytes.Buffer{}
|
||||
buf.ReadFrom(body)
|
||||
|
||||
httpErr := &HTTPError{
|
||||
Status: r.StatusCode,
|
||||
Status: status,
|
||||
Body: buf.Bytes(),
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return nil, httpErr
|
||||
}
|
||||
// Optionally unmarshal the error.
|
||||
c.Unmarshal(httpErr.Body, &httpErr)
|
||||
|
||||
httpErr.Body = b
|
||||
|
||||
c.Unmarshal(b, &httpErr)
|
||||
return nil, httpErr
|
||||
}
|
||||
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (c *Client) RequestCtxJSON(ctx context.Context,
|
||||
to interface{}, method, url string, opts ...RequestOption) error {
|
||||
|
||||
r, err := c.RequestCtx(ctx, method, url,
|
||||
append([]RequestOption{JSONRequest}, opts...)...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer r.Body.Close()
|
||||
|
||||
// No content, working as intended (tm)
|
||||
if r.StatusCode == http.StatusNoContent {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := c.DecodeStream(r.Body, to); err != nil {
|
||||
return JSONError{err}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) Request(
|
||||
method, url string, opts ...RequestOption) (*http.Response, error) {
|
||||
|
||||
func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.Response, error) {
|
||||
return c.RequestCtx(context.Background(), method, url, opts...)
|
||||
}
|
||||
|
||||
func (c *Client) RequestJSON(
|
||||
to interface{}, method, url string, opts ...RequestOption) error {
|
||||
|
||||
func (c *Client) RequestJSON(to interface{}, method, url string, opts ...RequestOption) error {
|
||||
return c.RequestCtxJSON(context.Background(), to, method, url, opts...)
|
||||
}
|
||||
|
|
|
@ -1,36 +0,0 @@
|
|||
package httputil
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type TransportWrapper struct {
|
||||
Default http.RoundTripper
|
||||
Pre func(*http.Request) error
|
||||
Post func(*http.Request, *http.Response) error
|
||||
}
|
||||
|
||||
var _ http.RoundTripper = (*TransportWrapper)(nil)
|
||||
|
||||
func NewTransportWrapper() *TransportWrapper {
|
||||
return &TransportWrapper{
|
||||
Default: http.DefaultTransport,
|
||||
Pre: func(*http.Request) error { return nil },
|
||||
Post: func(*http.Request, *http.Response) error { return nil },
|
||||
}
|
||||
}
|
||||
|
||||
func (c *TransportWrapper) RoundTrip(req *http.Request) (r *http.Response, err error) {
|
||||
if err := c.Pre(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r, err = c.Default.RoundTrip(req)
|
||||
|
||||
// Call Post regardless of error:
|
||||
if postErr := c.Post(req, r); postErr != nil {
|
||||
return r, postErr
|
||||
}
|
||||
|
||||
return r, err
|
||||
}
|
103
utils/httputil/httpdriver/default.go
Normal file
103
utils/httputil/httpdriver/default.go
Normal file
|
@ -0,0 +1,103 @@
|
|||
package httpdriver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
// DefaultClient implements Client and wraps around the stdlib Client.
|
||||
type DefaultClient struct {
|
||||
Client http.Client
|
||||
}
|
||||
|
||||
var _ Client = (*DefaultClient)(nil)
|
||||
|
||||
// WrapClient wraps around the standard library's http.Client and returns an
|
||||
// implementation that's compatible with the Client driver interface.
|
||||
func WrapClient(client http.Client) Client {
|
||||
return &DefaultClient{
|
||||
Client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// NewClient creates a new client around the standard library's http.Client. The
|
||||
// client will have a timeout of 10 seconds.
|
||||
func NewClient() Client {
|
||||
return WrapClient(http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
})
|
||||
}
|
||||
|
||||
func (d *DefaultClient) NewRequest(ctx context.Context, method, url string) (Request, error) {
|
||||
r, err := http.NewRequestWithContext(ctx, method, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return (*DefaultRequest)(r), nil
|
||||
}
|
||||
|
||||
func (d *DefaultClient) Do(req Request) (Response, error) {
|
||||
// Implementations can safely assert this.
|
||||
request := req.(*DefaultRequest)
|
||||
|
||||
r, err := d.Client.Do((*http.Request)(request))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return (*DefaultResponse)(r), nil
|
||||
}
|
||||
|
||||
// DefaultRequest wraps around the stdlib Request and satisfies the Request
|
||||
// interface.
|
||||
type DefaultRequest http.Request
|
||||
|
||||
var _ Request = (*DefaultRequest)(nil)
|
||||
|
||||
func (r *DefaultRequest) GetPath() string {
|
||||
return r.URL.Path
|
||||
}
|
||||
|
||||
func (r *DefaultRequest) GetContext() context.Context {
|
||||
return (*http.Request)(r).Context()
|
||||
}
|
||||
|
||||
func (r *DefaultRequest) AddQuery(values url.Values) {
|
||||
var qs = r.URL.Query()
|
||||
for k, v := range values {
|
||||
qs[k] = append(qs[k], v...)
|
||||
}
|
||||
|
||||
r.URL.RawQuery = qs.Encode()
|
||||
}
|
||||
|
||||
func (r *DefaultRequest) AddHeader(header http.Header) {
|
||||
for key, values := range header {
|
||||
r.Header[key] = append(r.Header[key], values...)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *DefaultRequest) WithBody(body io.ReadCloser) {
|
||||
r.Body = body
|
||||
}
|
||||
|
||||
// DefaultResponse wraps around the stdlib Response and satisfies the Response
|
||||
// interface.
|
||||
type DefaultResponse http.Response
|
||||
|
||||
var _ Response = (*DefaultResponse)(nil)
|
||||
|
||||
func (r *DefaultResponse) GetStatus() int {
|
||||
return r.StatusCode
|
||||
}
|
||||
|
||||
func (r *DefaultResponse) GetHeader() http.Header {
|
||||
return r.Header
|
||||
}
|
||||
|
||||
func (r *DefaultResponse) GetBody() io.ReadCloser {
|
||||
return r.Body
|
||||
}
|
55
utils/httputil/httpdriver/driver.go
Normal file
55
utils/httputil/httpdriver/driver.go
Normal file
|
@ -0,0 +1,55 @@
|
|||
// Package httpdriver provides interfaces and implementations of a simple HTTP
|
||||
// client.
|
||||
package httpdriver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// NoContent is the status code for HTTP 204, or http.StatusNoContent.
|
||||
const NoContent = 204
|
||||
|
||||
// Client is a generic interface used as an adapter allowing for custom HTTP
|
||||
// client implementations, such as fasthttp.
|
||||
type Client interface {
|
||||
NewRequest(ctx context.Context, method, url string) (Request, error)
|
||||
Do(req Request) (Response, error)
|
||||
}
|
||||
|
||||
// Request is a generic interface for a normal HTTP request. It should be
|
||||
// constructed using (Requester).NewRequest().
|
||||
type Request interface {
|
||||
// GetPath should return the URL path, for example "/endpoint/thing".
|
||||
GetPath() string
|
||||
// GetContext should return the same context that's passed into NewRequest.
|
||||
// For implementations that don't support this, it can remove a
|
||||
// context.Background().
|
||||
GetContext() context.Context
|
||||
// AddHeader appends headers.
|
||||
AddHeader(http.Header)
|
||||
// AddQuery appends URL query values.
|
||||
AddQuery(url.Values)
|
||||
// WithBody should automatically close the ReadCloser on finish. This is
|
||||
// similar to the stdlib's Request behavior.
|
||||
WithBody(io.ReadCloser)
|
||||
}
|
||||
|
||||
// Response is returned from (Requester).DoContext().
|
||||
type Response interface {
|
||||
GetStatus() int
|
||||
GetHeader() http.Header
|
||||
// Body's ReadCloser will always be closed when done, unless DoContext()
|
||||
// returns an error.
|
||||
GetBody() io.ReadCloser
|
||||
}
|
||||
|
||||
// OptHeader returns the response header, or nil if from is nil.
|
||||
func OptHeader(from Response) http.Header {
|
||||
if from == nil {
|
||||
return nil
|
||||
}
|
||||
return from.GetHeader()
|
||||
}
|
|
@ -4,86 +4,93 @@ import (
|
|||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/diamondburned/arikawa/utils/httputil/httpdriver"
|
||||
"github.com/diamondburned/arikawa/utils/json"
|
||||
)
|
||||
|
||||
type RequestOption func(*http.Request) error
|
||||
type RequestOption func(httpdriver.Request) error
|
||||
|
||||
func JSONRequest(r *http.Request) error {
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
func PrependOptions(opts []RequestOption, prepend ...RequestOption) []RequestOption {
|
||||
if len(opts) == 0 {
|
||||
return prepend
|
||||
}
|
||||
return append(prepend, opts...)
|
||||
}
|
||||
|
||||
func JSONRequest(r httpdriver.Request) error {
|
||||
r.AddHeader(http.Header{
|
||||
"Content-Type": {"application/json"},
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func MultipartRequest(r *http.Request) error {
|
||||
r.Header.Set("Content-Type", "multipart/form-data")
|
||||
func MultipartRequest(r httpdriver.Request) error {
|
||||
r.AddHeader(http.Header{
|
||||
"Content-Type": {"multipart/form-data"},
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func WithHeaders(headers http.Header) RequestOption {
|
||||
return func(r *http.Request) error {
|
||||
for key, values := range headers {
|
||||
r.Header[key] = append(r.Header[key], values...)
|
||||
}
|
||||
return func(r httpdriver.Request) error {
|
||||
r.AddHeader(headers)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithContentType(ctype string) RequestOption {
|
||||
return func(r *http.Request) error {
|
||||
r.Header.Set("Content-Type", ctype)
|
||||
return func(r httpdriver.Request) error {
|
||||
r.AddHeader(http.Header{
|
||||
"Content-Type": {ctype},
|
||||
})
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithSchema(schema SchemaEncoder, v interface{}) RequestOption {
|
||||
return func(r *http.Request) error {
|
||||
return func(r httpdriver.Request) error {
|
||||
params, err := schema.Encode(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var qs = r.URL.Query()
|
||||
for k, v := range params {
|
||||
qs[k] = append(qs[k], v...)
|
||||
}
|
||||
|
||||
r.URL.RawQuery = qs.Encode()
|
||||
r.AddQuery(params)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithBody(body io.ReadCloser) RequestOption {
|
||||
return func(r *http.Request) error {
|
||||
// tee := io.TeeReader(body, os.Stderr)
|
||||
// r.Body = ioutil.NopCloser(tee)
|
||||
r.Body = body
|
||||
r.ContentLength = -1
|
||||
return func(r httpdriver.Request) error {
|
||||
r.WithBody(body)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// WithJSONBody inserts a JSON body into the request. This ignores JSON errors.
|
||||
func WithJSONBody(json json.Driver, v interface{}) RequestOption {
|
||||
if v == nil {
|
||||
return func(*http.Request) error {
|
||||
return func(httpdriver.Request) error {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
var rp, wp = io.Pipe()
|
||||
|
||||
go func() {
|
||||
err = json.EncodeStream(wp, v)
|
||||
json.EncodeStream(wp, v)
|
||||
wp.Close()
|
||||
}()
|
||||
|
||||
return func(r *http.Request) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return func(r httpdriver.Request) error {
|
||||
// TODO: maybe do something to this?
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
r.Body = rp
|
||||
r.AddHeader(http.Header{
|
||||
"Content-Type": {"application/json"},
|
||||
})
|
||||
r.WithBody(rp)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue