1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2024-12-11 16:05:00 +00:00
arikawa/utils/wsutil/ws.go

201 lines
4.6 KiB
Go
Raw Normal View History

2020-01-15 04:56:50 +00:00
// Package wsutil provides abstractions around the Websocket, including rate
// limits.
2020-01-09 05:24:45 +00:00
package wsutil
import (
"context"
"log"
"sync"
2020-01-09 05:24:45 +00:00
"time"
"github.com/pkg/errors"
"golang.org/x/time/rate"
)
var (
	// WSTimeout bounds how long connecting and writing to the Websocket may
	// take before the Gateway gives up and fails.
	WSTimeout time.Duration = 30 * time.Second

	// WSBuffer is the capacity of the Event channel. It must be at least 1 so
	// the first Event (Ready or Resumed) has room.
	WSBuffer int = 10

	// WSError is the fallback error handler; by default it logs the error.
	WSError func(error) = func(err error) {
		log.Println("Gateway error:", err)
	}

	// WSDebug receives extra debug output and should behave like
	// log.Println(). It is a no-op by default.
	WSDebug func(...interface{}) = func(v ...interface{}) {}
)
2020-01-09 05:24:45 +00:00
// Event is a single message read from the Websocket: either a raw payload in
// Data or, when Data is nil, a non-nil Error describing what went wrong.
type Event struct {
	// Data is the raw message payload.
	Data []byte

	// Error is non-nil if Data is nil.
	Error error
}
// Websocket is a wrapper around a websocket Conn with thread safety and rate
// limiting for sending and throttling.
2020-01-09 05:24:45 +00:00
type Websocket struct {
mutex sync.Mutex
conn Connection
addr string
closed bool
2020-11-01 18:09:41 +00:00
sendLimiter *rate.Limiter
dialLimiter *rate.Limiter
// Timeout is the default timeout used if a context with no deadline is
// given to Dial.
//
// It must not be changed after the Websocket is used once.
Timeout time.Duration
2020-01-09 05:24:45 +00:00
}
// New creates a default Websocket with the given address.
func New(addr string) *Websocket {
return NewCustom(NewConn(), addr)
2020-01-15 04:43:34 +00:00
}
2020-01-09 05:24:45 +00:00
2020-01-15 04:43:34 +00:00
// NewCustom creates a new undialed Websocket.
func NewCustom(conn Connection, addr string) *Websocket {
return &Websocket{
conn: conn,
addr: addr,
closed: true,
2020-01-15 04:43:34 +00:00
2020-11-01 18:09:41 +00:00
sendLimiter: NewSendLimiter(),
dialLimiter: NewDialLimiter(),
2020-11-01 18:09:41 +00:00
Timeout: WSTimeout,
2020-01-09 05:24:45 +00:00
}
}
// Dial waits until the rate limiter allows then dials the websocket.
//
// If the passed context has no deadline, Dial will wrap it in a
// context.WithTimeout using ws.Timeout as timeout.
func (ws *Websocket) Dial(ctx context.Context) error {
if _, ok := ctx.Deadline(); !ok && ws.Timeout > 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, ws.Timeout)
defer cancel()
}
2020-11-01 18:09:41 +00:00
if err := ws.dialLimiter.Wait(ctx); err != nil {
2020-01-15 04:43:34 +00:00
// Expired, fatal error
2020-05-16 21:14:49 +00:00
return errors.Wrap(err, "failed to wait")
2020-01-15 04:43:34 +00:00
}
2020-01-09 05:24:45 +00:00
ws.mutex.Lock()
defer ws.mutex.Unlock()
if !ws.closed {
WSDebug("Old connection not yet closed while dialog; closing it.")
ws.conn.Close()
}
if err := ws.conn.Dial(ctx, ws.addr); err != nil {
2020-05-16 21:14:49 +00:00
return errors.Wrap(err, "failed to dial")
2020-01-09 05:24:45 +00:00
}
2020-01-15 04:43:34 +00:00
ws.closed = false
2020-11-01 18:09:41 +00:00
// Reset the send limiter.
ws.sendLimiter = NewSendLimiter()
2020-01-15 04:43:34 +00:00
return nil
2020-01-09 05:24:45 +00:00
}
// Listen returns the inner event channel or nil if the Websocket connection is
// not alive.
2020-01-09 05:24:45 +00:00
func (ws *Websocket) Listen() <-chan Event {
ws.mutex.Lock()
defer ws.mutex.Unlock()
if ws.closed {
return nil
}
return ws.conn.Listen()
2020-01-09 05:24:45 +00:00
}
// Send is SendCtx with a background context, so it never times out on its
// own.
func (ws *Websocket) Send(b []byte) error {
	background := context.Background()
	return ws.SendCtx(background, b)
}
// SendCtx sends b over the Websocket with a deadline. It closes the internal
// Websocket if the Send method errors out.
func (ws *Websocket) SendCtx(ctx context.Context, b []byte) error {
2020-11-01 17:44:02 +00:00
WSDebug("Waiting for the send rate limiter...")
2020-11-01 18:09:41 +00:00
if err := ws.sendLimiter.Wait(ctx); err != nil {
2020-11-01 17:44:02 +00:00
WSDebug("Send rate limiter timed out.")
2020-01-09 05:24:45 +00:00
return errors.Wrap(err, "SendLimiter failed")
}
2021-06-10 23:48:32 +00:00
WSDebug("Send has passed the rate limiting. Waiting on mutex.")
2020-11-01 17:44:02 +00:00
ws.mutex.Lock()
defer ws.mutex.Unlock()
2020-11-01 17:44:02 +00:00
WSDebug("Mutex lock acquired.")
if ws.closed {
return ErrWebsocketClosed
}
if err := ws.conn.Send(ctx, b); err != nil {
// We need to clean up ourselves if things are erroring out.
WSDebug("Conn: Error while sending; closing the connection. Error:", err)
ws.close(false)
return err
}
return nil
}
// Close shuts down the websocket connection non-gracefully. The Websocket is
// considered closed afterwards even if an error is returned; calling Close on
// an already-closed Websocket yields ErrWebsocketClosed.
func (ws *Websocket) Close() error {
	const graceful = false

	WSDebug("Conn: Acquiring mutex lock to close...")
	ws.mutex.Lock()
	defer ws.mutex.Unlock()
	WSDebug("Conn: Write mutex acquired")

	return ws.close(graceful)
}
// CloseGracefully closes the websocket connection through the underlying
// connection's CloseGracefully method instead of Close. Like Close, it treats
// the Websocket as closed even when an error is returned, and it returns
// ErrWebsocketClosed if the Websocket was already closed.
func (ws *Websocket) CloseGracefully() error {
	WSDebug("Conn: Acquiring mutex lock to close...")
	ws.mutex.Lock()
	defer ws.mutex.Unlock()

	WSDebug("Conn: Write mutex acquired")
	return ws.close(true)
}
// close closes the Websocket without acquiring the mutex. Refer to Close for
// more information.
func (ws *Websocket) close(graceful bool) error {
if ws.closed {
WSDebug("Conn: Websocket is already closed.")
return ErrWebsocketClosed
}
ws.closed = true
if graceful {
WSDebug("Conn: Closing gracefully")
return ws.conn.CloseGracefully()
}
WSDebug("Conn: Closing")
return ws.conn.Close()
2020-01-09 05:24:45 +00:00
}