1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2024-12-01 03:03:48 +00:00
arikawa/internal/wsutil/conn.go

180 lines
3.6 KiB
Go
Raw Normal View History

2020-01-09 05:24:45 +00:00
package wsutil
import (
"compress/zlib"
"context"
2020-01-29 03:54:22 +00:00
"io"
2020-01-09 05:24:45 +00:00
"io/ioutil"
"net/http"
"sync"
2020-01-09 05:24:45 +00:00
2020-01-29 03:54:22 +00:00
stderr "errors"
2020-01-15 18:32:54 +00:00
"github.com/diamondburned/arikawa/internal/json"
2020-01-09 05:24:45 +00:00
"github.com/pkg/errors"
"nhooyr.io/websocket"
)
// WSReadLimit is the maximum size, in bytes, of a single message read from
// the websocket. Dial applies it to every connection it establishes. Note that
// 8192000 bytes is about 8 MB, slightly less than 8 MiB (8388608 bytes).
var WSReadLimit int64 = 8192000
// Connection is an interface that abstracts around a generic Websocket driver.
// This connection expects the driver to handle compression by itself.
type Connection interface {
2020-01-15 04:43:34 +00:00
// Dial dials the address (string). Context needs to be passed in for
// timeout. This method should also be re-usable after Close is called.
Dial(context.Context, string) error
2020-01-09 05:24:45 +00:00
// Listen sends over events constantly. Error will be non-nil if Data is
// nil, so check for Error first.
Listen() <-chan Event
// Send allows the caller to send bytes. Context needs to be passed in order
// to re-use the context that's already used for the limiter.
Send(context.Context, []byte) error
// Close should close the websocket connection. The connection will not be
// reused.
// If error is nil, the connection should close with a StatusNormalClosure
// (1000). If not, it should close with a StatusProtocolError (1002).
Close(err error) error
}
// Conn is the default Websocket connection. It compresses all payloads using
// zlib.
type Conn struct {
Conn *websocket.Conn
2020-01-09 05:24:45 +00:00
json.Driver
mut sync.Mutex
2020-01-09 05:24:45 +00:00
events chan Event
}
var _ Connection = (*Conn)(nil)
func NewConn(driver json.Driver) *Conn {
return &Conn{
2020-01-17 22:29:13 +00:00
Driver: driver,
2020-02-02 22:12:54 +00:00
events: make(chan Event),
2020-01-09 05:24:45 +00:00
}
}
func (c *Conn) Dial(ctx context.Context, addr string) error {
var err error
headers := http.Header{}
headers.Set("Accept-Encoding", "zlib") // enable
c.mut.Lock()
defer c.mut.Unlock()
2020-01-09 05:24:45 +00:00
c.Conn, _, err = websocket.Dial(ctx, addr, &websocket.DialOptions{
HTTPHeader: headers,
})
2020-02-11 17:23:42 +00:00
if err != nil {
return errors.Wrap(err, "Failed to dial WS")
}
c.Conn.SetReadLimit(WSReadLimit)
2020-02-02 22:12:54 +00:00
c.events = make(chan Event)
2020-01-29 03:54:22 +00:00
c.readLoop()
2020-01-09 05:24:45 +00:00
return err
}
func (c *Conn) Listen() <-chan Event {
return c.events
}
2020-01-29 03:54:22 +00:00
func (c *Conn) readLoop() {
2020-02-02 22:12:54 +00:00
conn := c.Conn
go func() {
2020-01-29 03:54:22 +00:00
defer close(c.events)
for {
2020-02-02 22:12:54 +00:00
b, err := readAll(conn, context.Background())
if err != nil {
2020-01-29 03:54:22 +00:00
// Is the error an EOF?
if stderr.Is(err, io.EOF) {
// Yes it is, exit.
return
}
// Check if the error is a fatal one
if code := websocket.CloseStatus(err); code > -1 {
2020-02-02 22:12:54 +00:00
// Is the exit normal?
if code == websocket.StatusNormalClosure {
return
}
2020-01-16 03:28:21 +00:00
}
2020-02-02 22:12:54 +00:00
// Unusual error; log:
2020-01-29 03:54:22 +00:00
c.events <- Event{nil, errors.Wrap(err, "WS error")}
2020-02-02 22:12:54 +00:00
return
2020-01-09 05:24:45 +00:00
}
2020-01-29 03:54:22 +00:00
c.events <- Event{b, nil}
2020-01-09 05:24:45 +00:00
}
}()
2020-01-09 05:24:45 +00:00
}
2020-02-02 22:12:54 +00:00
func readAll(c *websocket.Conn, ctx context.Context) ([]byte, error) {
t, r, err := c.Reader(ctx)
2020-01-09 05:24:45 +00:00
if err != nil {
2020-01-16 03:28:21 +00:00
return nil, err
2020-01-09 05:24:45 +00:00
}
if t == websocket.MessageBinary {
// Probably a zlib payload
z, err := zlib.NewReader(r)
if err != nil {
2020-02-02 22:12:54 +00:00
c.CloseRead(ctx)
2020-01-09 05:24:45 +00:00
return nil,
errors.Wrap(err, "Failed to create a zlib reader")
}
defer z.Close()
r = z
}
b, err := ioutil.ReadAll(r)
2020-01-09 05:24:45 +00:00
if err != nil {
2020-02-02 22:12:54 +00:00
c.CloseRead(ctx)
return nil, err
2020-01-09 05:24:45 +00:00
}
return b, nil
}
2020-01-09 05:24:45 +00:00
func (c *Conn) Send(ctx context.Context, b []byte) error {
// TODO: zlib stream
return c.Conn.Write(ctx, websocket.MessageText, b)
2020-01-09 05:24:45 +00:00
}
func (c *Conn) Close(err error) error {
// Wait for the read loop to exit after exiting.
defer c.close()
2020-01-09 05:24:45 +00:00
if err == nil {
return c.Conn.Close(websocket.StatusNormalClosure, "")
}
var msg = err.Error()
if len(msg) > 125 {
msg = msg[:125] // truncate
}
return c.Conn.Close(websocket.StatusProtocolError, msg)
}
// close waits for the read loop started by Dial to finish, then clears the
// events channel and the connection so the Conn can be dialed again.
//
// The read loop closes the events channel when it exits, so draining the
// channel until it is closed guarantees the goroutine is gone before the
// fields are reset. Receiving a single event is not enough, because the loop
// may still be blocked delivering a later one.
func (c *Conn) close() {
	c.mut.Lock()
	defer c.mut.Unlock()

	// A read loop only exists after a successful Dial, which leaves c.Conn
	// non-nil. Without one, nothing would ever close the channel, so waiting
	// on it would block forever.
	if c.Conn != nil && c.events != nil {
		for range c.events {
			// Discard events nobody consumed; the connection is going away.
		}
	}

	c.events = nil
	// Set the connection to nil.
	c.Conn = nil
}