From 03d226e23d6fe455947e923f74402da58e53364c Mon Sep 17 00:00:00 2001 From: "diamondburned (Forefront)" Date: Tue, 14 Jan 2020 23:34:18 -0800 Subject: [PATCH] WIP integration test --- api/guild.go | 6 ++-- discord/user.go | 2 +- gateway/commands.go | 23 +++++++++++++-- gateway/events.go | 2 +- gateway/gateway.go | 56 +++++++++++++++++++++++-------------- gateway/identify.go | 46 ++++++++++++++++++++++++++++++ gateway/integration_test.go | 50 +++++++++++++++++++++++++++++++++ gateway/op.go | 18 ++++++++---- gateway/pacemaker.go | 9 +++--- gateway/sequence.go | 4 +-- gateway/shards.go | 16 +++++++++++ go.mod | 2 ++ go.sum | 8 ++++++ json/json.go | 5 ++++ wsutil/conn.go | 10 +++---- wsutil/throttler.go | 8 ++++++ 16 files changed, 218 insertions(+), 47 deletions(-) create mode 100644 gateway/identify.go create mode 100644 gateway/integration_test.go create mode 100644 gateway/shards.go diff --git a/api/guild.go b/api/guild.go index 89c9646..02a620b 100644 --- a/api/guild.go +++ b/api/guild.go @@ -402,11 +402,11 @@ func (c *Client) Integrations( // AttachIntegration requires MANAGE_GUILD. func (c *Client) AttachIntegration( guildID, integrationID discord.Snowflake, - integrationType discord.IntegrationType) error { + integrationType discord.Service) error { var param struct { - Type discord.IntegrationType `json:"type"` - ID discord.Snowflake `json:"id"` + Type discord.Service `json:"type"` + ID discord.Snowflake `json:"id"` } return c.FastRequest( diff --git a/discord/user.go b/discord/user.go index 35611f8..c1ee290 100644 --- a/discord/user.go +++ b/discord/user.go @@ -59,7 +59,7 @@ type Connection struct { Visibility ConnectionVisibility `json:"visibility"` // Only partial - Integratioons []Integration `json:"integrations"` + Integrations []Integration `json:"integrations"` } type ConnectionVisibility uint8 diff --git a/gateway/commands.go b/gateway/commands.go index 9d24121..5b851c1 100644 --- a/gateway/commands.go +++ b/gateway/commands.go @@ -1,7 +1,10 @@ package gateway import ( + "context" + "github.com/diamondburned/arikawa/discord" + "github.com/pkg/errors" ) // Rules: VOICE_STATE_UPDATE -> VoiceStateUpdateEvent @@ -14,9 +17,16 @@ type IdentifyData struct { LargeThreshold uint `json:"large_threshold,omitempty"` // 50 GuildSubscription bool `json:"guild_subscriptions"` // true - Shard [2]int `json:"shard"` // [ shard_id, num_shards ] + Shard *Shard `json:"shard,omitempty"` // [ shard_id, num_shards ] - Presence UpdateStatusData `json:"presence,omitempty"` + Presence *UpdateStatusData `json:"presence,omitempty"` +} + +func (i *IdentifyData) SetShard(id, num int) { + if i.Shard == nil { + i.Shard = new(Shard) + } + i.Shard[0], i.Shard[1] = id, num } type IdentifyProperties struct { @@ -34,7 +44,14 @@ type IdentifyProperties struct { } func (g *Gateway) Identify() error { - return g.Send(IdentifyOP, g.Identity) + ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) + defer cancel() + + if err := g.Identifier.Wait(ctx); err != nil { + return errors.Wrap(err, "Can't wait for identify()") + } + + return g.Send(IdentifyOP, g.Identifier) } type ResumeData struct { diff --git a/gateway/events.go b/gateway/events.go index 4539ac8..99f8c9c 100644 --- a/gateway/events.go +++ b/gateway/events.go @@ -19,7 +19,7 @@ type ( PrivateChannels []discord.Channel `json:"private_channels"` Guilds []discord.Guild `json:"guilds"` - Shard [2]int `json:"shard"` // [ shard_id num_shards ] + Shard *Shard `json:"shard"` } ResumedEvent struct{} diff --git a/gateway/gateway.go b/gateway/gateway.go index 31fec58..e51669b 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -9,6 +9,7 @@ package gateway import ( "context" + "log" "net/url" "runtime" "time" @@ -35,6 +36,11 @@ var ( // WSBuffer is the size of the Event channel. This has to be at least 1 to // make space for the first Event: Ready or Resumed. WSBuffer = 10 + // WSRetries is the times Gateway would try and connect or reconnect to the + // gateway. + WSRetries = uint(5) + // WSError is the default error handler + WSError = func(err error) {} ) var ( @@ -75,9 +81,9 @@ type Gateway struct { SessionID string - Identity *IdentifyData - Pacemaker *Pacemaker - Sequence Sequence + Identifier *Identifier + Pacemaker *Pacemaker + Sequence *Sequence ErrorLog func(err error) // default to log.Println @@ -106,17 +112,13 @@ func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) { } g := &Gateway{ - Driver: driver, - WSTimeout: WSTimeout, - Events: make(chan Event, WSBuffer), - Identity: &IdentifyData{ - Token: token, - Properties: Identity, - Compress: true, - LargeThreshold: 50, - GuildSubscription: true, - }, - Sequence: NewSequence(), + Driver: driver, + WSTimeout: WSTimeout, + WSRetries: WSRetries, + Events: make(chan Event, WSBuffer), + Identifier: DefaultIdentifier(token), + Sequence: NewSequence(), + ErrorLog: WSError, } // Parameters for the gateway @@ -166,7 +168,7 @@ func (g *Gateway) Resume() error { } return g.Send(ResumeOP, ResumeData{ - Token: g.Identity.Token, + Token: g.Identifier.Token, SessionID: ses, Sequence: seq, }) @@ -181,7 +183,7 @@ func (g *Gateway) Start() error { // Wait for an OP 10 Hello var hello HelloEvent - if err := AssertEvent(g, <-ch, HelloOP, &hello); err != nil { + if _, err := AssertEvent(g, <-ch, HelloOP, &hello); err != nil { return errors.Wrap(err, "Error at Hello") } @@ -195,9 +197,16 @@ func (g *Gateway) Start() error { // We should now expect a Ready event. var ready ReadyEvent - if err := AssertEvent(g, <-ch, DispatchOP, &ready); err != nil { + p, err := AssertEvent(g, <-ch, DispatchOP, &ready) + if err != nil { return errors.Wrap(err, "Error at Ready") } + + // We now also have the SessionID and the SequenceID + g.SessionID = ready.SessionID + g.Sequence.Set(p.Sequence) + + // Send the event away g.Events <- &ready } else { @@ -207,9 +216,12 @@ func (g *Gateway) Start() error { // We should now expect a Resumed event. var resumed ResumedEvent - if err := AssertEvent(g, <-ch, DispatchOP, &resumed); err != nil { + _, err := AssertEvent(g, <-ch, DispatchOP, &resumed) + if err != nil { return errors.Wrap(err, "Error at Resumed") } + + // Send the event away g.Events <- &resumed } @@ -245,7 +257,7 @@ func (g *Gateway) handleWS(stop <-chan struct{}) { case ev := <-ch: // Check for error if ev.Error != nil { - g.ErrorLog(errors.Wrap(ev.Error, "WS error")) + g.ErrorLog(ev.Error) continue } @@ -263,7 +275,7 @@ func (g *Gateway) Send(code OPCode, v interface{}) error { } if v != nil { - b, err := g.Marshal(v) + b, err := g.Driver.Marshal(v) if err != nil { return errors.Wrap(err, "Failed to encode v") } @@ -271,11 +283,13 @@ func (g *Gateway) Send(code OPCode, v interface{}) error { op.Data = b } - b, err := g.Marshal(op) + b, err := g.Driver.Marshal(op) if err != nil { return errors.Wrap(err, "Failed to encode payload") } + log.Println("->", len(b), string(b)) + ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) defer cancel() diff --git a/gateway/identify.go b/gateway/identify.go new file mode 100644 index 0000000..2a6e71d --- /dev/null +++ b/gateway/identify.go @@ -0,0 +1,46 @@ +package gateway + +import ( + "context" + "time" + + "github.com/pkg/errors" + "golang.org/x/time/rate" +) + +type Identifier struct { + IdentifyData + + IdentifyShortLimit *rate.Limiter `json:"-"` + IdentifyGlobalLimit *rate.Limiter `json:"-"` +} + +func DefaultIdentifier(token string) *Identifier { + return NewIdentifier(IdentifyData{ + Token: token, + Properties: Identity, + Shard: DefaultShard(), + + Compress: true, + LargeThreshold: 50, + GuildSubscription: true, + }) +} + +func NewIdentifier(data IdentifyData) *Identifier { + return &Identifier{ + IdentifyData: data, + IdentifyShortLimit: rate.NewLimiter(rate.Every(5*time.Second), 1), + IdentifyGlobalLimit: rate.NewLimiter(rate.Every(24*time.Hour), 1000), + } +} + +func (i *Identifier) Wait(ctx context.Context) error { + if err := i.IdentifyShortLimit.Wait(ctx); err != nil { + return errors.Wrap(err, "Can't wait for short limit") + } + if err := i.IdentifyGlobalLimit.Wait(ctx); err != nil { + return errors.Wrap(err, "Can't wait for global limit") + } + return nil +} diff --git a/gateway/integration_test.go b/gateway/integration_test.go new file mode 100644 index 0000000..0b59cbd --- /dev/null +++ b/gateway/integration_test.go @@ -0,0 +1,50 @@ +// +build integration + +package gateway + +import ( + "log" + "os" + "testing" +) + +func TestIntegration(t *testing.T) { + var token = os.Getenv("BOT_TOKEN") + if token == "" { + t.Fatal("Missing $BOT_TOKEN") + } + + WSError = func(err error) { + t.Error("WS:", err) + } + + var gateway *Gateway + + // NewGateway should call Start for us. + g, err := NewGateway(token) + if err != nil { + t.Fatal("Failed to make a Gateway:", err) + } + gateway = g + + ready, ok := (<-gateway.Events).(*ReadyEvent) + if !ok { + t.Fatal("Event received is not of type Ready:", ready) + } + + if gateway.SessionID == "" { + t.Fatal("Session ID is empty") + } + + log.Println("Bot's username is", ready.User.Username) + + // Try and reconnect + if err := gateway.Reconnect(); err != nil { + t.Fatal("Failed to reconnect:", err) + } + + resumed, ok := (<-gateway.Events).(*ResumedEvent) + if !ok { + t.Fatal("Event received is not of type Resumed:", resumed) + } +} diff --git a/gateway/op.go b/gateway/op.go index 7732440..3a0a26a 100644 --- a/gateway/op.go +++ b/gateway/op.go @@ -2,6 +2,7 @@ package gateway import ( "fmt" + "log" "github.com/diamondburned/arikawa/json" "github.com/diamondburned/arikawa/wsutil" @@ -33,7 +34,7 @@ type OP struct { Data json.Raw `json:"d,omitempty"` // Only for Dispatch (op 0) - Sequence int `json:"s,omitempty"` + Sequence int64 `json:"s,omitempty"` EventName string `json:"t,omitempty"` } @@ -44,6 +45,8 @@ func DecodeOP(driver json.Driver, ev wsutil.Event) (*OP, error) { return nil, ev.Error } + log.Println("<-", string(ev.Data)) + var op *OP if err := driver.Unmarshal(ev.Data, &op); err != nil { return nil, errors.Wrap(err, "Failed to decode payload") @@ -72,25 +75,25 @@ func DecodeEvent(driver json.Driver, } func AssertEvent(driver json.Driver, - ev wsutil.Event, code OPCode, v interface{}) error { + ev wsutil.Event, code OPCode, v interface{}) (*OP, error) { op, err := DecodeOP(driver, ev) if err != nil { - return err + return nil, err } if op.Code != code { - return fmt.Errorf( + return op, fmt.Errorf( "Unexpected OP Code: %d, expected %d (%s)", op.Code, code, op.Data, ) } if err := driver.Unmarshal(op.Data, v); err != nil { - return errors.Wrap(err, "Failed to decode data") + return op, errors.Wrap(err, "Failed to decode data") } - return nil + return op, nil } func HandleEvent(g *Gateway, data []byte) error { @@ -126,6 +129,9 @@ func HandleEvent(g *Gateway, data []byte) error { return nil case DispatchOP: + // Set the sequence + g.Sequence.Set(op.Sequence) + // Check if we know the event fn, ok := EventCreator[op.EventName] if !ok { diff --git a/gateway/pacemaker.go b/gateway/pacemaker.go index cef57c8..fccfdb0 100644 --- a/gateway/pacemaker.go +++ b/gateway/pacemaker.go @@ -55,11 +55,10 @@ func (p *Pacemaker) Start() error { return err } - if !p.Dead() { - continue - } - if err := p.OnDead(); err != nil { - return err + if p.Dead() { + if err := p.OnDead(); err != nil { + return err + } } select { diff --git a/gateway/sequence.go b/gateway/sequence.go index 15e66d8..59b5bc5 100644 --- a/gateway/sequence.go +++ b/gateway/sequence.go @@ -6,8 +6,8 @@ type Sequence struct { seq int64 } -func NewSequence() Sequence { - return Sequence{0} +func NewSequence() *Sequence { + return &Sequence{0} } func (s *Sequence) Set(seq int64) { atomic.StoreInt64(&s.seq, seq) } diff --git a/gateway/shards.go b/gateway/shards.go new file mode 100644 index 0000000..ddc662f --- /dev/null +++ b/gateway/shards.go @@ -0,0 +1,16 @@ +package gateway + +type Shard [2]int + +func DefaultShard() *Shard { + var s = Shard([2]int{0, 1}) + return &s +} + +func (s Shard) ShardID() int { + return s[0] +} + +func (s Shard) NumShards() int { + return s[1] +} diff --git a/go.mod b/go.mod index beab7fa..a5cc4b7 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,8 @@ require ( github.com/bwmarrin/discordgo v0.20.2 github.com/gorilla/schema v1.1.0 github.com/gorilla/websocket v1.4.1 + github.com/k0kubun/pp v3.0.1+incompatible + github.com/mattn/go-colorable v0.1.4 // indirect github.com/pkg/errors v0.8.1 github.com/sasha-s/go-csync v0.0.0-20160729053059-3bc6c8bdb3fa go.uber.org/atomic v1.4.0 diff --git a/go.sum b/go.sum index 4c0db88..eb93d45 100644 --- a/go.sum +++ b/go.sum @@ -18,9 +18,15 @@ github.com/gorilla/websocket v1.4.0 h1:WDFjx/TMzVgy9VdMMQi2K2Emtwi2QcUQsztZ/zLaH github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/k0kubun/pp v3.0.1+incompatible h1:3tqvf7QgUnZ5tXO6pNAZlrvHgl6DvifjDrd9g2S9Z40= +github.com/k0kubun/pp v3.0.1+incompatible/go.mod h1:GWse8YhT0p8pT4ir3ZgBbfZild3tgzSScAn6HmfYukg= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA= +github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE= +github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -38,6 +44,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 h1:efeOvDhwQ29Dj3SdAV/MJf8oukgn+8D8WgaCaRMchF8= golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223 h1:DH4skfRX4EBpamg7iV4ZlCpblAHI6s6TDM39bFZumv8= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= diff --git a/json/json.go b/json/json.go index a0941c9..081399e 100644 --- a/json/json.go +++ b/json/json.go @@ -55,6 +55,11 @@ func (m Raw) MarshalJSON() ([]byte, error) { return m, nil } +func (m *Raw) UnmarshalJSON(data []byte) error { + *m = append((*m)[0:0], data...) + return nil +} + func (m Raw) String() string { return string(m) } diff --git a/wsutil/conn.go b/wsutil/conn.go index 35d7499..2d3ac6d 100644 --- a/wsutil/conn.go +++ b/wsutil/conn.go @@ -77,8 +77,6 @@ func (c *Conn) Listen() <-chan Event { } func (c *Conn) readLoop(ch chan Event) { - defer close(ch) - for { ctx, cancel := context.WithTimeout( context.Background(), c.ReadTimeout) @@ -124,15 +122,17 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) { } func (c *Conn) Send(ctx context.Context, b []byte) error { - w, err := c.Writer(ctx, websocket.MessageBinary) + // TODO: zlib stream + + w, err := c.Writer(ctx, websocket.MessageText) if err != nil { return errors.Wrap(err, "Failed to get WS writer") } defer w.Close() - // Compress with zlib by default. - w = zlib.NewWriter(w) + // Compress with zlib by default NOT. + // w = zlib.NewWriter(w) _, err = w.Write(b) return err diff --git a/wsutil/throttler.go b/wsutil/throttler.go index f586b5f..936553c 100644 --- a/wsutil/throttler.go +++ b/wsutil/throttler.go @@ -13,3 +13,11 @@ func NewSendLimiter() *rate.Limiter { func NewDialLimiter() *rate.Limiter { return rate.NewLimiter(rate.Every(5*time.Second), 1) } + +func NewIdentityLimiter() *rate.Limiter { + return NewDialLimiter() // same +} + +func NewGlobalIdentityLimiter() *rate.Limiter { + return rate.NewLimiter(rate.Every(24*time.Hour), 1000) +}