1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2024-11-17 20:32:48 +00:00

Merge branch 'master' of ssh://github.com/diamondburned/arikawa

This commit is contained in:
diamondburned (Forefront) 2020-01-15 20:44:30 -08:00
commit 478b66fffa
8 changed files with 121 additions and 97 deletions

View file

@ -1,3 +1,13 @@
# arikawa # arikawa
A Golang library for the Discord API. A Golang library for the Discord API.
## Testing
The package includes integration tests that require `$BOT_TOKEN`. To run these
tests, do
```sh
export BOT_TOKEN="<BOT_TOKEN>"
go test -tags integration ./...
```

View file

@ -60,6 +60,25 @@ type ResumeData struct {
Sequence int64 `json:"seq"` Sequence int64 `json:"seq"`
} }
// Resume sends to the Websocket a Resume OP, but it doesn't actually resume
// from a dead connection. Start() resumes from a dead connection.
func (g *Gateway) Resume() error {
var (
ses = g.SessionID
seq = g.Sequence.Get()
)
if ses == "" || seq == 0 {
return ErrMissingForResume
}
return g.Send(ResumeOP, ResumeData{
Token: g.Identifier.Token,
SessionID: ses,
Sequence: seq,
})
}
// HeartbeatData is the last sequence number to be sent. // HeartbeatData is the last sequence number to be sent.
type HeartbeatData int type HeartbeatData int

View file

@ -9,7 +9,6 @@ package gateway
import ( import (
"context" "context"
"log"
"net/url" "net/url"
"runtime" "runtime"
"time" "time"
@ -144,6 +143,13 @@ func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) {
// Close closes the underlying Websocket connection. // Close closes the underlying Websocket connection.
func (g *Gateway) Close() error { func (g *Gateway) Close() error {
// Stop the pacemaker
g.Pacemaker.Stop()
// Stop the event handler
defer close(g.handler)
// Stop the Websocket
return g.WS.Close(nil) return g.WS.Close(nil)
} }
@ -155,25 +161,6 @@ func (g *Gateway) Reconnect() error {
return g.connect() return g.connect()
} }
// Resume sends to the Websocket a Resume OP, but it doesn't actually resume
// from a dead connection. Start() resumes from a dead connection.
func (g *Gateway) Resume() error {
var (
ses = g.SessionID
seq = g.Sequence.Get()
)
if ses == "" || seq == 0 {
return ErrMissingForResume
}
return g.Send(ResumeOP, ResumeData{
Token: g.Identifier.Token,
SessionID: ses,
Sequence: seq,
})
}
// Start authenticates with the websocket, or resume from a dead Websocket // Start authenticates with the websocket, or resume from a dead Websocket
// connection. This function doesn't block. To block until a fatal error, use // connection. This function doesn't block. To block until a fatal error, use
// Wait(). // Wait().
@ -187,48 +174,6 @@ func (g *Gateway) Start() error {
return errors.Wrap(err, "Error at Hello") return errors.Wrap(err, "Error at Hello")
} }
// Send Discord either the Identify packet (if it's a fresh connection), or
// a Resume packet (if it's a dead connection).
if g.SessionID == "" {
// SessionID is empty, so this is a completely new session.
if err := g.Identify(); err != nil {
return errors.Wrap(err, "Failed to identify")
}
// We should now expect a Ready event.
var ready ReadyEvent
p, err := AssertEvent(g, <-ch, DispatchOP, &ready)
if err != nil {
return errors.Wrap(err, "Error at Ready")
}
// We now also have the SessionID and the SequenceID
g.SessionID = ready.SessionID
g.Sequence.Set(p.Sequence)
// Send the event away
g.Events <- &ready
} else {
if err := g.Resume(); err != nil {
return errors.Wrap(err, "Failed to resume")
}
// We should now expect a Resumed event.
var resumed ResumedEvent
_, err := AssertEvent(g, <-ch, DispatchOP, &resumed)
if err != nil {
return errors.Wrap(err, "Error at Resumed")
}
// Send the event away
g.Events <- &resumed
}
// Start the event handler
g.handler = make(chan struct{})
go g.handleWS(g.handler)
// Start the pacemaker with the heartrate received from Hello // Start the pacemaker with the heartrate received from Hello
g.Pacemaker = &Pacemaker{ g.Pacemaker = &Pacemaker{
Heartrate: hello.HeartbeatInterval.Duration(), Heartrate: hello.HeartbeatInterval.Duration(),
@ -238,11 +183,27 @@ func (g *Gateway) Start() error {
// Pacemaker dies here, only when it's fatal. // Pacemaker dies here, only when it's fatal.
g.paceDeath = g.Pacemaker.StartAsync() g.paceDeath = g.Pacemaker.StartAsync()
// Send Discord either the Identify packet (if it's a fresh connection), or
// a Resume packet (if it's a dead connection).
if g.SessionID == "" {
// SessionID is empty, so this is a completely new session.
if err := g.Identify(); err != nil {
return errors.Wrap(err, "Failed to identify")
}
} else {
if err := g.Resume(); err != nil {
return errors.Wrap(err, "Failed to resume")
}
}
// Start the event handler
g.handler = make(chan struct{})
go g.handleWS(g.handler)
return nil return nil
} }
func (g *Gateway) Wait() error { func (g *Gateway) Wait() error {
defer close(g.handler)
return <-g.paceDeath return <-g.paceDeath
} }
@ -263,7 +224,7 @@ func (g *Gateway) handleWS(stop <-chan struct{}) {
// Handle the event // Handle the event
if err := HandleEvent(g, ev.Data); err != nil { if err := HandleEvent(g, ev.Data); err != nil {
g.ErrorLog(errors.Wrap(ev.Error, "WS handler error")) g.ErrorLog(errors.Wrap(err, "WS handler error"))
} }
} }
} }
@ -288,8 +249,6 @@ func (g *Gateway) Send(code OPCode, v interface{}) error {
return errors.Wrap(err, "Failed to encode payload") return errors.Wrap(err, "Failed to encode payload")
} }
log.Println("->", len(b), string(b))
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel() defer cancel()

View file

@ -15,7 +15,7 @@ func TestIntegration(t *testing.T) {
} }
WSError = func(err error) { WSError = func(err error) {
t.Error("WS:", err) log.Println(err)
} }
var gateway *Gateway var gateway *Gateway
@ -43,8 +43,15 @@ func TestIntegration(t *testing.T) {
t.Fatal("Failed to reconnect:", err) t.Fatal("Failed to reconnect:", err)
} }
/* TODO: this isn't true, as Discord would keep sending Invalid Sessions.
resumed, ok := (<-gateway.Events).(*ResumedEvent) resumed, ok := (<-gateway.Events).(*ResumedEvent)
if !ok { if !ok {
t.Fatal("Event received is not of type Resumed:", resumed) t.Fatal("Event received is not of type Resumed:", resumed)
} }
*/
ready, ok = (<-gateway.Events).(*ReadyEvent)
if !ok {
t.Fatal("Event received is not of type Ready:", ready)
}
} }

View file

@ -2,7 +2,6 @@ package gateway
import ( import (
"fmt" "fmt"
"log"
"github.com/diamondburned/arikawa/internal/json" "github.com/diamondburned/arikawa/internal/json"
"github.com/diamondburned/arikawa/internal/wsutil" "github.com/diamondburned/arikawa/internal/wsutil"
@ -45,8 +44,6 @@ func DecodeOP(driver json.Driver, ev wsutil.Event) (*OP, error) {
return nil, ev.Error return nil, ev.Error
} }
log.Println("<-", string(ev.Data))
var op *OP var op *OP
if err := driver.Unmarshal(ev.Data, &op); err != nil { if err := driver.Unmarshal(ev.Data, &op); err != nil {
return nil, errors.Wrap(err, "Failed to decode payload") return nil, errors.Wrap(err, "Failed to decode payload")
@ -103,6 +100,10 @@ func HandleEvent(g *Gateway, data []byte) error {
return errors.Wrap(err, "OP error") return errors.Wrap(err, "OP error")
} }
return HandleOP(g, op)
}
func HandleOP(g *Gateway, op *OP) error {
if g.OP != nil { if g.OP != nil {
g.OP <- op g.OP <- op
} }
@ -130,7 +131,9 @@ func HandleEvent(g *Gateway, data []byte) error {
case DispatchOP: case DispatchOP:
// Set the sequence // Set the sequence
if op.Sequence > 0 {
g.Sequence.Set(op.Sequence) g.Sequence.Set(op.Sequence)
}
// Check if we know the event // Check if we know the event
fn, ok := EventCreator[op.EventName] fn, ok := EventCreator[op.EventName]
@ -146,6 +149,11 @@ func HandleEvent(g *Gateway, data []byte) error {
return errors.Wrap(err, "Failed to parse event "+op.EventName) return errors.Wrap(err, "Failed to parse event "+op.EventName)
} }
// If the event is a ready, we'll want its sessionID
if ev, ok := ev.(*ReadyEvent); ok {
g.SessionID = ev.SessionID
}
// Throw the event into a channel, it's valid now. // Throw the event into a channel, it's valid now.
g.Events <- ev g.Events <- ev
return nil return nil

View file

@ -44,13 +44,24 @@ func (p *Pacemaker) Stop() {
// Start beats until it's dead. // Start beats until it's dead.
func (p *Pacemaker) Start() error { func (p *Pacemaker) Start() error {
tick := time.NewTicker(p.Heartrate)
defer tick.Stop()
stop := make(chan struct{}) stop := make(chan struct{})
p.stop = stop p.stop = stop
return p.start(stop)
}
func (p *Pacemaker) start(stop chan struct{}) error {
tick := time.NewTicker(p.Heartrate)
defer tick.Stop()
// Echo at least once
p.Echo()
for { for {
select {
case <-stop:
return nil
case <-tick.C:
if err := p.Pace(); err != nil { if err := p.Pace(); err != nil {
return err return err
} }
@ -60,20 +71,19 @@ func (p *Pacemaker) Start() error {
return err return err
} }
} }
select {
case <-stop:
return nil
case <-tick.C:
continue
} }
} }
} }
func (p *Pacemaker) StartAsync() (death <-chan error) { func (p *Pacemaker) StartAsync() (death <-chan error) {
var ch = make(chan error) var ch = make(chan error)
stop := make(chan struct{})
p.stop = stop
go func() { go func() {
ch <- p.Start() ch <- p.start(stop)
}() }()
return ch return ch
} }

View file

@ -54,6 +54,7 @@ func NewConn(driver json.Driver) *Conn {
return &Conn{ return &Conn{
Driver: driver, Driver: driver,
ReadTimeout: DefaultTimeout, ReadTimeout: DefaultTimeout,
events: make(chan Event, WSBuffer),
} }
} }
@ -67,12 +68,14 @@ func (c *Conn) Dial(ctx context.Context, addr string) error {
HTTPHeader: headers, HTTPHeader: headers,
}) })
go func() {
c.readLoop(c.events)
}()
return err return err
} }
func (c *Conn) Listen() <-chan Event { func (c *Conn) Listen() <-chan Event {
c.events = make(chan Event, WSBuffer)
go func() { c.readLoop(c.events) }()
return c.events return c.events
} }
@ -84,15 +87,19 @@ func (c *Conn) readLoop(ch chan Event) {
b, err := c.readAll(ctx) b, err := c.readAll(ctx)
if err != nil { if err != nil {
ch <- Event{nil, errors.Wrap(err, "WS error")}
// Check if the error is a fatal one // Check if the error is a fatal one
if websocket.CloseStatus(err) > -1 { if code := websocket.CloseStatus(err); code > -1 {
// Error is fatal, exit // Is the exit unusual?
if code != websocket.StatusNormalClosure {
// Unusual error, log
ch <- Event{nil, errors.Wrap(err, "WS fatal")}
}
return return
} }
// or it's not fatal, we just continue // or it's not fatal, we just log and continue
ch <- Event{nil, errors.Wrap(err, "WS error")}
continue continue
} }
@ -103,7 +110,7 @@ func (c *Conn) readLoop(ch chan Event) {
func (c *Conn) readAll(ctx context.Context) ([]byte, error) { func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
t, r, err := c.Reader(ctx) t, r, err := c.Reader(ctx)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "WS error") return nil, err
} }
if t == websocket.MessageBinary { if t == websocket.MessageBinary {
@ -139,9 +146,6 @@ func (c *Conn) Send(ctx context.Context, b []byte) error {
} }
func (c *Conn) Close(err error) error { func (c *Conn) Close(err error) error {
// Close the event channels
defer close(c.events)
if err == nil { if err == nil {
return c.Conn.Close(websocket.StatusNormalClosure, "") return c.Conn.Close(websocket.StatusNormalClosure, "")
} }

View file

@ -28,6 +28,7 @@ type Websocket struct {
DialLimiter *rate.Limiter DialLimiter *rate.Limiter
listener <-chan Event listener <-chan Event
dialed bool
} }
func New(ctx context.Context, addr string) (*Websocket, error) { func New(ctx context.Context, addr string) (*Websocket, error) {
@ -55,6 +56,12 @@ func (ws *Websocket) Redial(ctx context.Context) error {
return errors.Wrap(err, "Failed to wait") return errors.Wrap(err, "Failed to wait")
} }
// Close the connection
if ws.dialed {
ws.Conn.Close(nil)
}
ws.dialed = true
if err := ws.Conn.Dial(ctx, ws.Addr); err != nil { if err := ws.Conn.Dial(ctx, ws.Addr); err != nil {
return errors.Wrap(err, "Failed to dial") return errors.Wrap(err, "Failed to dial")
} }