diff --git a/gateway/gateway.go b/gateway/gateway.go index f3db18a..ad6e198 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -14,13 +14,14 @@ import ( "sync" "time" + "github.com/gorilla/websocket" + "github.com/pkg/errors" + "github.com/diamondburned/arikawa/v3/api" "github.com/diamondburned/arikawa/v3/internal/moreatomic" "github.com/diamondburned/arikawa/v3/utils/json" "github.com/diamondburned/arikawa/v3/utils/json/option" "github.com/diamondburned/arikawa/v3/utils/wsutil" - "github.com/gorilla/websocket" - "github.com/pkg/errors" ) var ( @@ -359,9 +360,6 @@ func (g *Gateway) ReconnectCtx(ctx context.Context) (err error) { // this function over Start(). The given context provides cancellation and // timeout. func (g *Gateway) Open(ctx context.Context) error { - ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout) - defer cancel() - // Reconnect to the Gateway if err := g.WS.Dial(ctx); err != nil { return errors.Wrap(err, "failed to Reconnect") diff --git a/utils/wsutil/ws.go b/utils/wsutil/ws.go index 036dcaf..cfa3467 100644 --- a/utils/wsutil/ws.go +++ b/utils/wsutil/ws.go @@ -44,11 +44,10 @@ type Websocket struct { sendLimiter *rate.Limiter dialLimiter *rate.Limiter - // Constants. These must not be changed after the Websocket instance is used - // once, as they are not thread-safe. - - // Timeout for connecting and writing to the Websocket, uses default - // WSTimeout (global). + // Timeout is the default timeout used if a context with no deadline is + // given to Dial. + // + // It must not be changed after the Websocket is used once. Timeout time.Duration } @@ -72,12 +71,14 @@ func NewCustom(conn Connection, addr string) *Websocket { } // Dial waits until the rate limiter allows then dials the websocket. +// +// If the passed context has no deadline, Dial will wrap it in a +// context.WithTimeout using ws.Timeout as timeout. func (ws *Websocket) Dial(ctx context.Context) error { - if ws.Timeout > 0 { - tctx, cancel := context.WithTimeout(ctx, ws.Timeout) + if _, ok := ctx.Deadline(); !ok && ws.Timeout > 0 { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, ws.Timeout) defer cancel() - - ctx = tctx } if err := ws.dialLimiter.Wait(ctx); err != nil {