mirror of
https://github.com/diamondburned/arikawa.git
synced 2024-11-16 03:44:26 +00:00
231 lines
5.2 KiB
Go
231 lines
5.2 KiB
Go
package ws
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sync"
|
|
|
|
"github.com/diamondburned/arikawa/v3/utils/json"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// OpCode is the type for websocket Op codes. Op codes less than 0 are
// internal Op codes and should usually be ignored.
//
// In this file, CloseEvent and RawEvent both report the internal code -1.
type OpCode int
|
|
|
|
// CloseEvent is an event that is given from wsutil when the websocket is
|
|
// closed.
|
|
type CloseEvent struct {
|
|
// Err is the underlying error.
|
|
Err error
|
|
// Code is the websocket close code, if any.
|
|
Code int
|
|
}
|
|
|
|
// Unwrap returns err.Err.
|
|
func (e *CloseEvent) Unwrap() error { return e.Err }
|
|
|
|
// Error formats the CloseEvent. A CloseEvent is also an error.
|
|
func (e *CloseEvent) Error() string {
|
|
return fmt.Sprintf("websocket closed, reason: %s", e.Err)
|
|
}
|
|
|
|
// Op implements Event. It returns -1.
|
|
func (e *CloseEvent) Op() OpCode { return -1 }
|
|
|
|
// EventType implements Event. It returns an emty string.
|
|
func (e *CloseEvent) EventType() EventType { return "__ws.CloseEvent" }
|
|
|
|
// EnableRawEvents makes ws generate a RawEvent alongside every regular Event
// when set to true. It is false by default and is meant only for debugging.
var EnableRawEvents bool
|
|
|
|
// RawEvent is used if EnableRawEvents is true.
|
|
type RawEvent struct {
|
|
json.Raw
|
|
OriginalCode OpCode `json:"-"`
|
|
OriginalType EventType `json:"-"`
|
|
}
|
|
|
|
// Op implements Event. It returns -1.
|
|
func (e *RawEvent) Op() OpCode { return -1 }
|
|
|
|
// EventType implements Event. It returns an emty string.
|
|
func (e *RawEvent) EventType() EventType { return "__ws.RawEvent" }
|
|
|
|
// EventType is a type for event types, which is the "t" field in the payload.
|
|
type EventType string
|
|
|
|
// Event describes an Event data that comes from a gateway Operation.
|
|
type Event interface {
|
|
Op() OpCode
|
|
EventType() EventType
|
|
}
|
|
|
|
// OpFunc is a constructor function for an Operation.
|
|
type OpFunc func() Event
|
|
|
|
// OpUnmarshalers contains a map of event constructor function.
|
|
type OpUnmarshalers struct {
|
|
r map[opFuncID]OpFunc
|
|
}
|
|
|
|
type opFuncID struct {
|
|
Op OpCode `json:"op"`
|
|
T EventType `json:"t"`
|
|
}
|
|
|
|
// NewOpUnmarshalers creates a nwe OpUnmarshalers instance from the given
|
|
// constructor functions.
|
|
func NewOpUnmarshalers(funcs ...OpFunc) OpUnmarshalers {
|
|
m := OpUnmarshalers{r: make(map[opFuncID]OpFunc)}
|
|
m.Add(funcs...)
|
|
return m
|
|
}
|
|
|
|
// Each iterates over the marshaler map.
|
|
func (m OpUnmarshalers) Each(f func(OpCode, EventType, OpFunc) (done bool)) {
|
|
for id, fn := range m.r {
|
|
if f(id.Op, id.T, fn) {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Add adds the given functions into the unmarshaler registry.
|
|
func (m OpUnmarshalers) Add(funcs ...OpFunc) {
|
|
for _, fn := range funcs {
|
|
ev := fn()
|
|
id := opFuncID{
|
|
Op: ev.Op(),
|
|
T: ev.EventType(),
|
|
}
|
|
|
|
m.r[id] = fn
|
|
}
|
|
}
|
|
|
|
// Lookup searches the OpMarshalers map for the given constructor function.
|
|
func (m OpUnmarshalers) Lookup(op OpCode, t EventType) OpFunc {
|
|
return m.r[opFuncID{op, t}]
|
|
}
|
|
|
|
// Op is a gateway Operation: an op code paired with its Event data.
//
// NOTE(review): keep the field order as is — encoding/json-style encoders
// emit keys in declaration order, and any positional Op{...} literals in the
// package depend on it.
type Op struct {
	// Code is the op code (JSON key "op").
	Code OpCode `json:"op"`
	// Data is the Event payload (JSON key "d"), tagged omitempty.
	Data Event `json:"d,omitempty"`

	// Type is only for gateway dispatch events.
	Type EventType `json:"t,omitempty"`
	// Sequence is only for gateway dispatch events (Op 0).
	Sequence int64 `json:"s,omitempty"`
}
|
|
|
|
// UnknownEventError is required by HandleOp if an event is encountered that is
|
|
// not known. Internally, unknown events are logged and ignored. It is not a
|
|
// fatal error.
|
|
type UnknownEventError struct {
|
|
Op OpCode
|
|
Type EventType
|
|
}
|
|
|
|
// Error formats the unknown event error to with the event name and payload
|
|
func (err UnknownEventError) Error() string {
|
|
return fmt.Sprintf("unknown op %d, event %s", err.Op, err.Type)
|
|
}
|
|
|
|
// IsBrokenConnection returns true if the error is a broken connection error.
|
|
func IsUnknownEvent(err error) bool {
|
|
var uevent *UnknownEventError
|
|
return errors.As(err, &uevent)
|
|
}
|
|
|
|
// ReadOps reads maximum n Ops and accumulate them into a slice.
|
|
func ReadOps(ctx context.Context, ch <-chan Op, n int) ([]Op, error) {
|
|
ops := make([]Op, 0, n)
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return ops, ctx.Err()
|
|
case op := <-ch:
|
|
ops = append(ops, op)
|
|
if len(ops) == n {
|
|
return ops, nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// ReadOp reads a single Op.
|
|
func ReadOp(ctx context.Context, ch <-chan Op) (Op, error) {
|
|
select {
|
|
case <-ctx.Done():
|
|
return Op{}, ctx.Err()
|
|
case op := <-ch:
|
|
return op, nil
|
|
}
|
|
}
|
|
|
|
// Broadcaster is primarily used for debugging.
|
|
type Broadcaster struct {
|
|
src <-chan Op
|
|
dst map[chan<- Op]struct{}
|
|
mut sync.Mutex
|
|
void bool
|
|
}
|
|
|
|
// NewBroadcaster creates a new broadcaster.
|
|
func NewBroadcaster(src <-chan Op) *Broadcaster {
|
|
return &Broadcaster{
|
|
src: src,
|
|
dst: make(map[chan<- Op]struct{}),
|
|
}
|
|
}
|
|
|
|
// Start starts the broadcasting loop.
|
|
func (b *Broadcaster) Start() {
|
|
b.mut.Lock()
|
|
if b.void {
|
|
panic("Start called on voided Broadcaster")
|
|
}
|
|
b.mut.Unlock()
|
|
|
|
go func() {
|
|
for op := range b.src {
|
|
b.mut.Lock()
|
|
|
|
for ch := range b.dst {
|
|
ch <- op
|
|
}
|
|
|
|
b.mut.Unlock()
|
|
}
|
|
|
|
b.mut.Lock()
|
|
b.void = true
|
|
|
|
for ch := range b.dst {
|
|
close(ch)
|
|
}
|
|
|
|
b.mut.Unlock()
|
|
}()
|
|
}
|
|
|
|
// Subscribe subscribes the given channel
|
|
func (b *Broadcaster) Subscribe(ch chan<- Op) {
|
|
b.mut.Lock()
|
|
if b.void {
|
|
panic("Subscribe called on voided Broadcaster")
|
|
}
|
|
b.dst[ch] = struct{}{}
|
|
b.mut.Unlock()
|
|
}
|
|
|
|
// NewSubscribed creates a newly subscribed Op channel.
|
|
func (b *Broadcaster) NewSubscribed() <-chan Op {
|
|
ch := make(chan Op, 1)
|
|
b.Subscribe(ch)
|
|
return ch
|
|
}
|