mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-01-21 12:07:14 +00:00
127 lines
2.7 KiB
Go
127 lines
2.7 KiB
Go
// Package wsutil provides abstractions around the Websocket, including rate
|
|
// limits.
|
|
package wsutil
|
|
|
|
import (
|
|
"context"
|
|
"log"
|
|
"net/url"
|
|
"time"
|
|
|
|
"github.com/pkg/errors"
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
// Package-level defaults and hooks used by the websocket helpers.
var (
	// WSTimeout is how long connecting and writing to the Websocket may take
	// before Gateway cancels and fails.
	WSTimeout = time.Minute

	// WSBuffer is the size of the Event channel. It must be at least 1 so
	// that there is room for the first Event: Ready or Resumed.
	WSBuffer = 10

	// WSError is the default error handler. It prints the error through the
	// standard logger.
	WSError = func(err error) {
		log.Println("Gateway error:", err)
	}

	// WSExtraReadTimeout is the duration added to Hello and used as a read
	// timeout for the websocket.
	WSExtraReadTimeout = time.Second

	// WSDebug receives extra debug logging and is expected to behave
	// similarly to log.Println(). The default implementation does nothing.
	WSDebug = func(v ...interface{}) {}
)
|
|
|
|
// Event is a single item delivered on the channel returned by Listen. It
// carries either a raw payload or an error.
type Event struct {
	// Data is the raw payload of the event.
	Data []byte

	// Error is non-nil if Data is nil.
	Error error
}
|
|
|
|
type Websocket struct {
|
|
Conn Connection
|
|
Addr string
|
|
|
|
// Timeout for connecting and writing to the Websocket, uses default
|
|
// WSTimeout (global).
|
|
Timeout time.Duration
|
|
|
|
SendLimiter *rate.Limiter
|
|
DialLimiter *rate.Limiter
|
|
}
|
|
|
|
func New(addr string) *Websocket {
|
|
return NewCustom(NewConn(), addr)
|
|
}
|
|
|
|
// NewCustom creates a new undialed Websocket.
|
|
func NewCustom(conn Connection, addr string) *Websocket {
|
|
return &Websocket{
|
|
Conn: conn,
|
|
Addr: addr,
|
|
|
|
Timeout: WSTimeout,
|
|
|
|
SendLimiter: NewSendLimiter(),
|
|
DialLimiter: NewDialLimiter(),
|
|
}
|
|
}
|
|
|
|
func (ws *Websocket) Dial(ctx context.Context) error {
|
|
if ws.Timeout > 0 {
|
|
tctx, cancel := context.WithTimeout(ctx, ws.Timeout)
|
|
defer cancel()
|
|
|
|
ctx = tctx
|
|
}
|
|
|
|
if err := ws.DialLimiter.Wait(ctx); err != nil {
|
|
// Expired, fatal error
|
|
return errors.Wrap(err, "failed to wait")
|
|
}
|
|
|
|
if err := ws.Conn.Dial(ctx, ws.Addr); err != nil {
|
|
return errors.Wrap(err, "failed to dial")
|
|
}
|
|
|
|
// Reset the SendLimiter:
|
|
ws.SendLimiter = NewSendLimiter()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (ws *Websocket) Listen() <-chan Event {
|
|
return ws.Conn.Listen()
|
|
}
|
|
|
|
func (ws *Websocket) Send(b []byte) error {
|
|
return ws.SendCtx(context.Background(), b)
|
|
}
|
|
|
|
func (ws *Websocket) SendCtx(ctx context.Context, b []byte) error {
|
|
if err := ws.SendLimiter.Wait(ctx); err != nil {
|
|
return errors.Wrap(err, "SendLimiter failed")
|
|
}
|
|
|
|
return ws.Conn.Send(ctx, b)
|
|
}
|
|
|
|
func (ws *Websocket) Close() error {
|
|
return ws.Conn.Close()
|
|
}
|
|
|
|
// InjectValues appends the given query values to rawurl and returns the
// resulting URL string. Values for keys that are already present in rawurl
// are added after the existing ones. If rawurl cannot be parsed, it is
// returned unchanged.
func InjectValues(rawurl string, values url.Values) string {
	parsed, err := url.Parse(rawurl)
	if err != nil {
		// Unknown URL, return as-is.
		return rawurl
	}

	// Merge the extra parameters into the ones already in the URL.
	query := parsed.Query()
	for key, extras := range values {
		for _, extra := range extras {
			query.Add(key, extra)
		}
	}

	parsed.RawQuery = query.Encode()
	return parsed.String()
}
|