WIP integration test

This commit is contained in:
diamondburned (Forefront) 2020-01-14 23:34:18 -08:00
parent 9f643fee7a
commit 03d226e23d
16 changed files with 218 additions and 47 deletions

View File

@ -402,11 +402,11 @@ func (c *Client) Integrations(
// AttachIntegration requires MANAGE_GUILD.
func (c *Client) AttachIntegration(
guildID, integrationID discord.Snowflake,
integrationType discord.IntegrationType) error {
integrationType discord.Service) error {
var param struct {
Type discord.IntegrationType `json:"type"`
ID discord.Snowflake `json:"id"`
Type discord.Service `json:"type"`
ID discord.Snowflake `json:"id"`
}
return c.FastRequest(

View File

@ -59,7 +59,7 @@ type Connection struct {
Visibility ConnectionVisibility `json:"visibility"`
// Only partial
Integratioons []Integration `json:"integrations"`
Integrations []Integration `json:"integrations"`
}
type ConnectionVisibility uint8

View File

@ -1,7 +1,10 @@
package gateway
import (
"context"
"github.com/diamondburned/arikawa/discord"
"github.com/pkg/errors"
)
// Rules: VOICE_STATE_UPDATE -> VoiceStateUpdateEvent
@ -14,9 +17,16 @@ type IdentifyData struct {
LargeThreshold uint `json:"large_threshold,omitempty"` // 50
GuildSubscription bool `json:"guild_subscriptions"` // true
Shard [2]int `json:"shard"` // [ shard_id, num_shards ]
Shard *Shard `json:"shard,omitempty"` // [ shard_id, num_shards ]
Presence UpdateStatusData `json:"presence,omitempty"`
Presence *UpdateStatusData `json:"presence,omitempty"`
}
func (i *IdentifyData) SetShard(id, num int) {
if i.Shard == nil {
i.Shard = new(Shard)
}
i.Shard[0], i.Shard[1] = id, num
}
type IdentifyProperties struct {
@ -34,7 +44,14 @@ type IdentifyProperties struct {
}
func (g *Gateway) Identify() error {
return g.Send(IdentifyOP, g.Identity)
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
if err := g.Identifier.Wait(ctx); err != nil {
return errors.Wrap(err, "Can't wait for identify()")
}
return g.Send(IdentifyOP, g.Identifier)
}
type ResumeData struct {

View File

@ -19,7 +19,7 @@ type (
PrivateChannels []discord.Channel `json:"private_channels"`
Guilds []discord.Guild `json:"guilds"`
Shard [2]int `json:"shard"` // [ shard_id num_shards ]
Shard *Shard `json:"shard"`
}
ResumedEvent struct{}

View File

@ -9,6 +9,7 @@ package gateway
import (
"context"
"log"
"net/url"
"runtime"
"time"
@ -35,6 +36,11 @@ var (
// WSBuffer is the size of the Event channel. This has to be at least 1 to
// make space for the first Event: Ready or Resumed.
WSBuffer = 10
// WSRetries is the times Gateway would try and connect or reconnect to the
// gateway.
WSRetries = uint(5)
// WSError is the default error handler
WSError = func(err error) {}
)
var (
@ -75,9 +81,9 @@ type Gateway struct {
SessionID string
Identity *IdentifyData
Pacemaker *Pacemaker
Sequence Sequence
Identifier *Identifier
Pacemaker *Pacemaker
Sequence *Sequence
ErrorLog func(err error) // default to log.Println
@ -106,17 +112,13 @@ func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) {
}
g := &Gateway{
Driver: driver,
WSTimeout: WSTimeout,
Events: make(chan Event, WSBuffer),
Identity: &IdentifyData{
Token: token,
Properties: Identity,
Compress: true,
LargeThreshold: 50,
GuildSubscription: true,
},
Sequence: NewSequence(),
Driver: driver,
WSTimeout: WSTimeout,
WSRetries: WSRetries,
Events: make(chan Event, WSBuffer),
Identifier: DefaultIdentifier(token),
Sequence: NewSequence(),
ErrorLog: WSError,
}
// Parameters for the gateway
@ -166,7 +168,7 @@ func (g *Gateway) Resume() error {
}
return g.Send(ResumeOP, ResumeData{
Token: g.Identity.Token,
Token: g.Identifier.Token,
SessionID: ses,
Sequence: seq,
})
@ -181,7 +183,7 @@ func (g *Gateway) Start() error {
// Wait for an OP 10 Hello
var hello HelloEvent
if err := AssertEvent(g, <-ch, HelloOP, &hello); err != nil {
if _, err := AssertEvent(g, <-ch, HelloOP, &hello); err != nil {
return errors.Wrap(err, "Error at Hello")
}
@ -195,9 +197,16 @@ func (g *Gateway) Start() error {
// We should now expect a Ready event.
var ready ReadyEvent
if err := AssertEvent(g, <-ch, DispatchOP, &ready); err != nil {
p, err := AssertEvent(g, <-ch, DispatchOP, &ready)
if err != nil {
return errors.Wrap(err, "Error at Ready")
}
// We now also have the SessionID and the SequenceID
g.SessionID = ready.SessionID
g.Sequence.Set(p.Sequence)
// Send the event away
g.Events <- &ready
} else {
@ -207,9 +216,12 @@ func (g *Gateway) Start() error {
// We should now expect a Resumed event.
var resumed ResumedEvent
if err := AssertEvent(g, <-ch, DispatchOP, &resumed); err != nil {
_, err := AssertEvent(g, <-ch, DispatchOP, &resumed)
if err != nil {
return errors.Wrap(err, "Error at Resumed")
}
// Send the event away
g.Events <- &resumed
}
@ -245,7 +257,7 @@ func (g *Gateway) handleWS(stop <-chan struct{}) {
case ev := <-ch:
// Check for error
if ev.Error != nil {
g.ErrorLog(errors.Wrap(ev.Error, "WS error"))
g.ErrorLog(ev.Error)
continue
}
@ -263,7 +275,7 @@ func (g *Gateway) Send(code OPCode, v interface{}) error {
}
if v != nil {
b, err := g.Marshal(v)
b, err := g.Driver.Marshal(v)
if err != nil {
return errors.Wrap(err, "Failed to encode v")
}
@ -271,11 +283,13 @@ func (g *Gateway) Send(code OPCode, v interface{}) error {
op.Data = b
}
b, err := g.Marshal(op)
b, err := g.Driver.Marshal(op)
if err != nil {
return errors.Wrap(err, "Failed to encode payload")
}
log.Println("->", len(b), string(b))
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()

46
gateway/identify.go Normal file
View File

@ -0,0 +1,46 @@
package gateway
import (
"context"
"time"
"github.com/pkg/errors"
"golang.org/x/time/rate"
)
type Identifier struct {
IdentifyData
IdentifyShortLimit *rate.Limiter `json:"-"`
IdentifyGlobalLimit *rate.Limiter `json:"-"`
}
func DefaultIdentifier(token string) *Identifier {
return NewIdentifier(IdentifyData{
Token: token,
Properties: Identity,
Shard: DefaultShard(),
Compress: true,
LargeThreshold: 50,
GuildSubscription: true,
})
}
func NewIdentifier(data IdentifyData) *Identifier {
return &Identifier{
IdentifyData: data,
IdentifyShortLimit: rate.NewLimiter(rate.Every(5*time.Second), 1),
IdentifyGlobalLimit: rate.NewLimiter(rate.Every(24*time.Hour), 1000),
}
}
func (i *Identifier) Wait(ctx context.Context) error {
if err := i.IdentifyShortLimit.Wait(ctx); err != nil {
return errors.Wrap(err, "Can't wait for short limit")
}
if err := i.IdentifyGlobalLimit.Wait(ctx); err != nil {
return errors.Wrap(err, "Can't wait for global limit")
}
return nil
}

View File

@ -0,0 +1,50 @@
// +build integration
package gateway
import (
"log"
"os"
"testing"
)
func TestIntegration(t *testing.T) {
var token = os.Getenv("BOT_TOKEN")
if token == "" {
t.Fatal("Missing $BOT_TOKEN")
}
WSError = func(err error) {
t.Error("WS:", err)
}
var gateway *Gateway
// NewGateway should call Start for us.
g, err := NewGateway(token)
if err != nil {
t.Fatal("Failed to make a Gateway:", err)
}
gateway = g
ready, ok := (<-gateway.Events).(*ReadyEvent)
if !ok {
t.Fatal("Event received is not of type Ready:", ready)
}
if gateway.SessionID == "" {
t.Fatal("Session ID is empty")
}
log.Println("Bot's username is", ready.User.Username)
// Try and reconnect
if err := gateway.Reconnect(); err != nil {
t.Fatal("Failed to reconnect:", err)
}
resumed, ok := (<-gateway.Events).(*ResumedEvent)
if !ok {
t.Fatal("Event received is not of type Resumed:", resumed)
}
}

View File

@ -2,6 +2,7 @@ package gateway
import (
"fmt"
"log"
"github.com/diamondburned/arikawa/json"
"github.com/diamondburned/arikawa/wsutil"
@ -33,7 +34,7 @@ type OP struct {
Data json.Raw `json:"d,omitempty"`
// Only for Dispatch (op 0)
Sequence int `json:"s,omitempty"`
Sequence int64 `json:"s,omitempty"`
EventName string `json:"t,omitempty"`
}
@ -44,6 +45,8 @@ func DecodeOP(driver json.Driver, ev wsutil.Event) (*OP, error) {
return nil, ev.Error
}
log.Println("<-", string(ev.Data))
var op *OP
if err := driver.Unmarshal(ev.Data, &op); err != nil {
return nil, errors.Wrap(err, "Failed to decode payload")
@ -72,25 +75,25 @@ func DecodeEvent(driver json.Driver,
}
func AssertEvent(driver json.Driver,
ev wsutil.Event, code OPCode, v interface{}) error {
ev wsutil.Event, code OPCode, v interface{}) (*OP, error) {
op, err := DecodeOP(driver, ev)
if err != nil {
return err
return nil, err
}
if op.Code != code {
return fmt.Errorf(
return op, fmt.Errorf(
"Unexpected OP Code: %d, expected %d (%s)",
op.Code, code, op.Data,
)
}
if err := driver.Unmarshal(op.Data, v); err != nil {
return errors.Wrap(err, "Failed to decode data")
return op, errors.Wrap(err, "Failed to decode data")
}
return nil
return op, nil
}
func HandleEvent(g *Gateway, data []byte) error {
@ -126,6 +129,9 @@ func HandleEvent(g *Gateway, data []byte) error {
return nil
case DispatchOP:
// Set the sequence
g.Sequence.Set(op.Sequence)
// Check if we know the event
fn, ok := EventCreator[op.EventName]
if !ok {

View File

@ -55,11 +55,10 @@ func (p *Pacemaker) Start() error {
return err
}
if !p.Dead() {
continue
}
if err := p.OnDead(); err != nil {
return err
if p.Dead() {
if err := p.OnDead(); err != nil {
return err
}
}
select {

View File

@ -6,8 +6,8 @@ type Sequence struct {
seq int64
}
func NewSequence() Sequence {
return Sequence{0}
func NewSequence() *Sequence {
return &Sequence{0}
}
func (s *Sequence) Set(seq int64) { atomic.StoreInt64(&s.seq, seq) }

16
gateway/shards.go Normal file
View File

@ -0,0 +1,16 @@
package gateway
type Shard [2]int
func DefaultShard() *Shard {
var s = Shard([2]int{0, 1})
return &s
}
func (s Shard) ShardID() int {
return s[0]
}
func (s Shard) NumShards() int {
return s[1]
}

2
go.mod
View File

@ -6,6 +6,8 @@ require (
github.com/bwmarrin/discordgo v0.20.2
github.com/gorilla/schema v1.1.0
github.com/gorilla/websocket v1.4.1
github.com/k0kubun/pp v3.0.1+incompatible
github.com/mattn/go-colorable v0.1.4 // indirect
github.com/pkg/errors v0.8.1
github.com/sasha-s/go-csync v0.0.0-20160729053059-3bc6c8bdb3fa
go.uber.org/atomic v1.4.0

8
go.sum
View File

@ -18,9 +18,15 @@ github.com/gorilla/websocket v1.4.0 h1:WDFjx/TMzVgy9VdMMQi2K2Emtwi2QcUQsztZ/zLaH
github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM=
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/k0kubun/pp v3.0.1+incompatible h1:3tqvf7QgUnZ5tXO6pNAZlrvHgl6DvifjDrd9g2S9Z40=
github.com/k0kubun/pp v3.0.1+incompatible/go.mod h1:GWse8YhT0p8pT4ir3ZgBbfZild3tgzSScAn6HmfYukg=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA=
github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE=
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@ -38,6 +44,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 h1:efeOvDhwQ29Dj3SdAV/MJf8oukgn+8D8WgaCaRMchF8=
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223 h1:DH4skfRX4EBpamg7iV4ZlCpblAHI6s6TDM39bFZumv8=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=

View File

@ -55,6 +55,11 @@ func (m Raw) MarshalJSON() ([]byte, error) {
return m, nil
}
func (m *Raw) UnmarshalJSON(data []byte) error {
*m = append((*m)[0:0], data...)
return nil
}
func (m Raw) String() string {
return string(m)
}

View File

@ -77,8 +77,6 @@ func (c *Conn) Listen() <-chan Event {
}
func (c *Conn) readLoop(ch chan Event) {
defer close(ch)
for {
ctx, cancel := context.WithTimeout(
context.Background(), c.ReadTimeout)
@ -124,15 +122,17 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
}
func (c *Conn) Send(ctx context.Context, b []byte) error {
w, err := c.Writer(ctx, websocket.MessageBinary)
// TODO: zlib stream
w, err := c.Writer(ctx, websocket.MessageText)
if err != nil {
return errors.Wrap(err, "Failed to get WS writer")
}
defer w.Close()
// Compress with zlib by default.
w = zlib.NewWriter(w)
// Compress with zlib by default NOT.
// w = zlib.NewWriter(w)
_, err = w.Write(b)
return err

View File

@ -13,3 +13,11 @@ func NewSendLimiter() *rate.Limiter {
func NewDialLimiter() *rate.Limiter {
return rate.NewLimiter(rate.Every(5*time.Second), 1)
}
func NewIdentityLimiter() *rate.Limiter {
return NewDialLimiter() // same
}
func NewGlobalIdentityLimiter() *rate.Limiter {
return rate.NewLimiter(rate.Every(24*time.Hour), 1000)
}