mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-03-21 17:39:25 +00:00
WIP integration test
This commit is contained in:
parent
9f643fee7a
commit
03d226e23d
|
@ -402,11 +402,11 @@ func (c *Client) Integrations(
|
|||
// AttachIntegration requires MANAGE_GUILD.
|
||||
func (c *Client) AttachIntegration(
|
||||
guildID, integrationID discord.Snowflake,
|
||||
integrationType discord.IntegrationType) error {
|
||||
integrationType discord.Service) error {
|
||||
|
||||
var param struct {
|
||||
Type discord.IntegrationType `json:"type"`
|
||||
ID discord.Snowflake `json:"id"`
|
||||
Type discord.Service `json:"type"`
|
||||
ID discord.Snowflake `json:"id"`
|
||||
}
|
||||
|
||||
return c.FastRequest(
|
||||
|
|
|
@ -59,7 +59,7 @@ type Connection struct {
|
|||
Visibility ConnectionVisibility `json:"visibility"`
|
||||
|
||||
// Only partial
|
||||
Integratioons []Integration `json:"integrations"`
|
||||
Integrations []Integration `json:"integrations"`
|
||||
}
|
||||
|
||||
type ConnectionVisibility uint8
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/diamondburned/arikawa/discord"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Rules: VOICE_STATE_UPDATE -> VoiceStateUpdateEvent
|
||||
|
@ -14,9 +17,16 @@ type IdentifyData struct {
|
|||
LargeThreshold uint `json:"large_threshold,omitempty"` // 50
|
||||
GuildSubscription bool `json:"guild_subscriptions"` // true
|
||||
|
||||
Shard [2]int `json:"shard"` // [ shard_id, num_shards ]
|
||||
Shard *Shard `json:"shard,omitempty"` // [ shard_id, num_shards ]
|
||||
|
||||
Presence UpdateStatusData `json:"presence,omitempty"`
|
||||
Presence *UpdateStatusData `json:"presence,omitempty"`
|
||||
}
|
||||
|
||||
func (i *IdentifyData) SetShard(id, num int) {
|
||||
if i.Shard == nil {
|
||||
i.Shard = new(Shard)
|
||||
}
|
||||
i.Shard[0], i.Shard[1] = id, num
|
||||
}
|
||||
|
||||
type IdentifyProperties struct {
|
||||
|
@ -34,7 +44,14 @@ type IdentifyProperties struct {
|
|||
}
|
||||
|
||||
func (g *Gateway) Identify() error {
|
||||
return g.Send(IdentifyOP, g.Identity)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
if err := g.Identifier.Wait(ctx); err != nil {
|
||||
return errors.Wrap(err, "Can't wait for identify()")
|
||||
}
|
||||
|
||||
return g.Send(IdentifyOP, g.Identifier)
|
||||
}
|
||||
|
||||
type ResumeData struct {
|
||||
|
|
|
@ -19,7 +19,7 @@ type (
|
|||
PrivateChannels []discord.Channel `json:"private_channels"`
|
||||
Guilds []discord.Guild `json:"guilds"`
|
||||
|
||||
Shard [2]int `json:"shard"` // [ shard_id num_shards ]
|
||||
Shard *Shard `json:"shard"`
|
||||
}
|
||||
|
||||
ResumedEvent struct{}
|
||||
|
|
|
@ -9,6 +9,7 @@ package gateway
|
|||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"net/url"
|
||||
"runtime"
|
||||
"time"
|
||||
|
@ -35,6 +36,11 @@ var (
|
|||
// WSBuffer is the size of the Event channel. This has to be at least 1 to
|
||||
// make space for the first Event: Ready or Resumed.
|
||||
WSBuffer = 10
|
||||
// WSRetries is the times Gateway would try and connect or reconnect to the
|
||||
// gateway.
|
||||
WSRetries = uint(5)
|
||||
// WSError is the default error handler
|
||||
WSError = func(err error) {}
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -75,9 +81,9 @@ type Gateway struct {
|
|||
|
||||
SessionID string
|
||||
|
||||
Identity *IdentifyData
|
||||
Pacemaker *Pacemaker
|
||||
Sequence Sequence
|
||||
Identifier *Identifier
|
||||
Pacemaker *Pacemaker
|
||||
Sequence *Sequence
|
||||
|
||||
ErrorLog func(err error) // default to log.Println
|
||||
|
||||
|
@ -106,17 +112,13 @@ func NewGatewayWithDriver(token string, driver json.Driver) (*Gateway, error) {
|
|||
}
|
||||
|
||||
g := &Gateway{
|
||||
Driver: driver,
|
||||
WSTimeout: WSTimeout,
|
||||
Events: make(chan Event, WSBuffer),
|
||||
Identity: &IdentifyData{
|
||||
Token: token,
|
||||
Properties: Identity,
|
||||
Compress: true,
|
||||
LargeThreshold: 50,
|
||||
GuildSubscription: true,
|
||||
},
|
||||
Sequence: NewSequence(),
|
||||
Driver: driver,
|
||||
WSTimeout: WSTimeout,
|
||||
WSRetries: WSRetries,
|
||||
Events: make(chan Event, WSBuffer),
|
||||
Identifier: DefaultIdentifier(token),
|
||||
Sequence: NewSequence(),
|
||||
ErrorLog: WSError,
|
||||
}
|
||||
|
||||
// Parameters for the gateway
|
||||
|
@ -166,7 +168,7 @@ func (g *Gateway) Resume() error {
|
|||
}
|
||||
|
||||
return g.Send(ResumeOP, ResumeData{
|
||||
Token: g.Identity.Token,
|
||||
Token: g.Identifier.Token,
|
||||
SessionID: ses,
|
||||
Sequence: seq,
|
||||
})
|
||||
|
@ -181,7 +183,7 @@ func (g *Gateway) Start() error {
|
|||
|
||||
// Wait for an OP 10 Hello
|
||||
var hello HelloEvent
|
||||
if err := AssertEvent(g, <-ch, HelloOP, &hello); err != nil {
|
||||
if _, err := AssertEvent(g, <-ch, HelloOP, &hello); err != nil {
|
||||
return errors.Wrap(err, "Error at Hello")
|
||||
}
|
||||
|
||||
|
@ -195,9 +197,16 @@ func (g *Gateway) Start() error {
|
|||
|
||||
// We should now expect a Ready event.
|
||||
var ready ReadyEvent
|
||||
if err := AssertEvent(g, <-ch, DispatchOP, &ready); err != nil {
|
||||
p, err := AssertEvent(g, <-ch, DispatchOP, &ready)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Error at Ready")
|
||||
}
|
||||
|
||||
// We now also have the SessionID and the SequenceID
|
||||
g.SessionID = ready.SessionID
|
||||
g.Sequence.Set(p.Sequence)
|
||||
|
||||
// Send the event away
|
||||
g.Events <- &ready
|
||||
|
||||
} else {
|
||||
|
@ -207,9 +216,12 @@ func (g *Gateway) Start() error {
|
|||
|
||||
// We should now expect a Resumed event.
|
||||
var resumed ResumedEvent
|
||||
if err := AssertEvent(g, <-ch, DispatchOP, &resumed); err != nil {
|
||||
_, err := AssertEvent(g, <-ch, DispatchOP, &resumed)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Error at Resumed")
|
||||
}
|
||||
|
||||
// Send the event away
|
||||
g.Events <- &resumed
|
||||
}
|
||||
|
||||
|
@ -245,7 +257,7 @@ func (g *Gateway) handleWS(stop <-chan struct{}) {
|
|||
case ev := <-ch:
|
||||
// Check for error
|
||||
if ev.Error != nil {
|
||||
g.ErrorLog(errors.Wrap(ev.Error, "WS error"))
|
||||
g.ErrorLog(ev.Error)
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -263,7 +275,7 @@ func (g *Gateway) Send(code OPCode, v interface{}) error {
|
|||
}
|
||||
|
||||
if v != nil {
|
||||
b, err := g.Marshal(v)
|
||||
b, err := g.Driver.Marshal(v)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to encode v")
|
||||
}
|
||||
|
@ -271,11 +283,13 @@ func (g *Gateway) Send(code OPCode, v interface{}) error {
|
|||
op.Data = b
|
||||
}
|
||||
|
||||
b, err := g.Marshal(op)
|
||||
b, err := g.Driver.Marshal(op)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to encode payload")
|
||||
}
|
||||
|
||||
log.Println("->", len(b), string(b))
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
|
||||
defer cancel()
|
||||
|
||||
|
|
46
gateway/identify.go
Normal file
46
gateway/identify.go
Normal file
|
@ -0,0 +1,46 @@
|
|||
package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type Identifier struct {
|
||||
IdentifyData
|
||||
|
||||
IdentifyShortLimit *rate.Limiter `json:"-"`
|
||||
IdentifyGlobalLimit *rate.Limiter `json:"-"`
|
||||
}
|
||||
|
||||
func DefaultIdentifier(token string) *Identifier {
|
||||
return NewIdentifier(IdentifyData{
|
||||
Token: token,
|
||||
Properties: Identity,
|
||||
Shard: DefaultShard(),
|
||||
|
||||
Compress: true,
|
||||
LargeThreshold: 50,
|
||||
GuildSubscription: true,
|
||||
})
|
||||
}
|
||||
|
||||
func NewIdentifier(data IdentifyData) *Identifier {
|
||||
return &Identifier{
|
||||
IdentifyData: data,
|
||||
IdentifyShortLimit: rate.NewLimiter(rate.Every(5*time.Second), 1),
|
||||
IdentifyGlobalLimit: rate.NewLimiter(rate.Every(24*time.Hour), 1000),
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Identifier) Wait(ctx context.Context) error {
|
||||
if err := i.IdentifyShortLimit.Wait(ctx); err != nil {
|
||||
return errors.Wrap(err, "Can't wait for short limit")
|
||||
}
|
||||
if err := i.IdentifyGlobalLimit.Wait(ctx); err != nil {
|
||||
return errors.Wrap(err, "Can't wait for global limit")
|
||||
}
|
||||
return nil
|
||||
}
|
50
gateway/integration_test.go
Normal file
50
gateway/integration_test.go
Normal file
|
@ -0,0 +1,50 @@
|
|||
// +build integration
|
||||
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIntegration(t *testing.T) {
|
||||
var token = os.Getenv("BOT_TOKEN")
|
||||
if token == "" {
|
||||
t.Fatal("Missing $BOT_TOKEN")
|
||||
}
|
||||
|
||||
WSError = func(err error) {
|
||||
t.Error("WS:", err)
|
||||
}
|
||||
|
||||
var gateway *Gateway
|
||||
|
||||
// NewGateway should call Start for us.
|
||||
g, err := NewGateway(token)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to make a Gateway:", err)
|
||||
}
|
||||
gateway = g
|
||||
|
||||
ready, ok := (<-gateway.Events).(*ReadyEvent)
|
||||
if !ok {
|
||||
t.Fatal("Event received is not of type Ready:", ready)
|
||||
}
|
||||
|
||||
if gateway.SessionID == "" {
|
||||
t.Fatal("Session ID is empty")
|
||||
}
|
||||
|
||||
log.Println("Bot's username is", ready.User.Username)
|
||||
|
||||
// Try and reconnect
|
||||
if err := gateway.Reconnect(); err != nil {
|
||||
t.Fatal("Failed to reconnect:", err)
|
||||
}
|
||||
|
||||
resumed, ok := (<-gateway.Events).(*ResumedEvent)
|
||||
if !ok {
|
||||
t.Fatal("Event received is not of type Resumed:", resumed)
|
||||
}
|
||||
}
|
|
@ -2,6 +2,7 @@ package gateway
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/diamondburned/arikawa/json"
|
||||
"github.com/diamondburned/arikawa/wsutil"
|
||||
|
@ -33,7 +34,7 @@ type OP struct {
|
|||
Data json.Raw `json:"d,omitempty"`
|
||||
|
||||
// Only for Dispatch (op 0)
|
||||
Sequence int `json:"s,omitempty"`
|
||||
Sequence int64 `json:"s,omitempty"`
|
||||
EventName string `json:"t,omitempty"`
|
||||
}
|
||||
|
||||
|
@ -44,6 +45,8 @@ func DecodeOP(driver json.Driver, ev wsutil.Event) (*OP, error) {
|
|||
return nil, ev.Error
|
||||
}
|
||||
|
||||
log.Println("<-", string(ev.Data))
|
||||
|
||||
var op *OP
|
||||
if err := driver.Unmarshal(ev.Data, &op); err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to decode payload")
|
||||
|
@ -72,25 +75,25 @@ func DecodeEvent(driver json.Driver,
|
|||
}
|
||||
|
||||
func AssertEvent(driver json.Driver,
|
||||
ev wsutil.Event, code OPCode, v interface{}) error {
|
||||
ev wsutil.Event, code OPCode, v interface{}) (*OP, error) {
|
||||
|
||||
op, err := DecodeOP(driver, ev)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if op.Code != code {
|
||||
return fmt.Errorf(
|
||||
return op, fmt.Errorf(
|
||||
"Unexpected OP Code: %d, expected %d (%s)",
|
||||
op.Code, code, op.Data,
|
||||
)
|
||||
}
|
||||
|
||||
if err := driver.Unmarshal(op.Data, v); err != nil {
|
||||
return errors.Wrap(err, "Failed to decode data")
|
||||
return op, errors.Wrap(err, "Failed to decode data")
|
||||
}
|
||||
|
||||
return nil
|
||||
return op, nil
|
||||
}
|
||||
|
||||
func HandleEvent(g *Gateway, data []byte) error {
|
||||
|
@ -126,6 +129,9 @@ func HandleEvent(g *Gateway, data []byte) error {
|
|||
return nil
|
||||
|
||||
case DispatchOP:
|
||||
// Set the sequence
|
||||
g.Sequence.Set(op.Sequence)
|
||||
|
||||
// Check if we know the event
|
||||
fn, ok := EventCreator[op.EventName]
|
||||
if !ok {
|
||||
|
|
|
@ -55,11 +55,10 @@ func (p *Pacemaker) Start() error {
|
|||
return err
|
||||
}
|
||||
|
||||
if !p.Dead() {
|
||||
continue
|
||||
}
|
||||
if err := p.OnDead(); err != nil {
|
||||
return err
|
||||
if p.Dead() {
|
||||
if err := p.OnDead(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
|
|
|
@ -6,8 +6,8 @@ type Sequence struct {
|
|||
seq int64
|
||||
}
|
||||
|
||||
func NewSequence() Sequence {
|
||||
return Sequence{0}
|
||||
func NewSequence() *Sequence {
|
||||
return &Sequence{0}
|
||||
}
|
||||
|
||||
func (s *Sequence) Set(seq int64) { atomic.StoreInt64(&s.seq, seq) }
|
||||
|
|
16
gateway/shards.go
Normal file
16
gateway/shards.go
Normal file
|
@ -0,0 +1,16 @@
|
|||
package gateway
|
||||
|
||||
type Shard [2]int
|
||||
|
||||
func DefaultShard() *Shard {
|
||||
var s = Shard([2]int{0, 1})
|
||||
return &s
|
||||
}
|
||||
|
||||
func (s Shard) ShardID() int {
|
||||
return s[0]
|
||||
}
|
||||
|
||||
func (s Shard) NumShards() int {
|
||||
return s[1]
|
||||
}
|
2
go.mod
2
go.mod
|
@ -6,6 +6,8 @@ require (
|
|||
github.com/bwmarrin/discordgo v0.20.2
|
||||
github.com/gorilla/schema v1.1.0
|
||||
github.com/gorilla/websocket v1.4.1
|
||||
github.com/k0kubun/pp v3.0.1+incompatible
|
||||
github.com/mattn/go-colorable v0.1.4 // indirect
|
||||
github.com/pkg/errors v0.8.1
|
||||
github.com/sasha-s/go-csync v0.0.0-20160729053059-3bc6c8bdb3fa
|
||||
go.uber.org/atomic v1.4.0
|
||||
|
|
8
go.sum
8
go.sum
|
@ -18,9 +18,15 @@ github.com/gorilla/websocket v1.4.0 h1:WDFjx/TMzVgy9VdMMQi2K2Emtwi2QcUQsztZ/zLaH
|
|||
github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
|
||||
github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM=
|
||||
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/k0kubun/pp v3.0.1+incompatible h1:3tqvf7QgUnZ5tXO6pNAZlrvHgl6DvifjDrd9g2S9Z40=
|
||||
github.com/k0kubun/pp v3.0.1+incompatible/go.mod h1:GWse8YhT0p8pT4ir3ZgBbfZild3tgzSScAn6HmfYukg=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA=
|
||||
github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
|
||||
github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE=
|
||||
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
|
||||
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
|
@ -38,6 +44,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
|||
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 h1:efeOvDhwQ29Dj3SdAV/MJf8oukgn+8D8WgaCaRMchF8=
|
||||
golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223 h1:DH4skfRX4EBpamg7iV4ZlCpblAHI6s6TDM39bFZumv8=
|
||||
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
|
||||
|
|
|
@ -55,6 +55,11 @@ func (m Raw) MarshalJSON() ([]byte, error) {
|
|||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Raw) UnmarshalJSON(data []byte) error {
|
||||
*m = append((*m)[0:0], data...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m Raw) String() string {
|
||||
return string(m)
|
||||
}
|
||||
|
|
|
@ -77,8 +77,6 @@ func (c *Conn) Listen() <-chan Event {
|
|||
}
|
||||
|
||||
func (c *Conn) readLoop(ch chan Event) {
|
||||
defer close(ch)
|
||||
|
||||
for {
|
||||
ctx, cancel := context.WithTimeout(
|
||||
context.Background(), c.ReadTimeout)
|
||||
|
@ -124,15 +122,17 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
|
|||
}
|
||||
|
||||
func (c *Conn) Send(ctx context.Context, b []byte) error {
|
||||
w, err := c.Writer(ctx, websocket.MessageBinary)
|
||||
// TODO: zlib stream
|
||||
|
||||
w, err := c.Writer(ctx, websocket.MessageText)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to get WS writer")
|
||||
}
|
||||
|
||||
defer w.Close()
|
||||
|
||||
// Compress with zlib by default.
|
||||
w = zlib.NewWriter(w)
|
||||
// Compress with zlib by default NOT.
|
||||
// w = zlib.NewWriter(w)
|
||||
|
||||
_, err = w.Write(b)
|
||||
return err
|
||||
|
|
|
@ -13,3 +13,11 @@ func NewSendLimiter() *rate.Limiter {
|
|||
func NewDialLimiter() *rate.Limiter {
|
||||
return rate.NewLimiter(rate.Every(5*time.Second), 1)
|
||||
}
|
||||
|
||||
func NewIdentityLimiter() *rate.Limiter {
|
||||
return NewDialLimiter() // same
|
||||
}
|
||||
|
||||
func NewGlobalIdentityLimiter() *rate.Limiter {
|
||||
return rate.NewLimiter(rate.Every(24*time.Hour), 1000)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue