1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2024-10-04 00:29:04 +00:00

gateway: Fix Context overwrite in Gateway.Open (#285)

* Gateway: Fix Gateway.Open overwriting the context argument

* WSUtil: Remove max context timeout in Websocket.Dial

* WSUtil: Use Websocket.Timeout if a no-deadline context is given to .Dial

* WSUtil: Add doc to Websocket.Timeout clarifying that it must not be changed after use
This commit is contained in:
Maximilian von Lindern 2021-10-21 00:06:06 +02:00 committed by GitHub
parent 918cce64e9
commit 528281b739
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 13 additions and 14 deletions

View file

@ -14,13 +14,14 @@ import (
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
"github.com/diamondburned/arikawa/v3/api"
"github.com/diamondburned/arikawa/v3/internal/moreatomic"
"github.com/diamondburned/arikawa/v3/utils/json"
"github.com/diamondburned/arikawa/v3/utils/json/option"
"github.com/diamondburned/arikawa/v3/utils/wsutil"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
)
var (
@ -359,9 +360,6 @@ func (g *Gateway) ReconnectCtx(ctx context.Context) (err error) {
// this function over Start(). The given context provides cancellation and
// timeout.
func (g *Gateway) Open(ctx context.Context) error {
ctx, cancel := context.WithTimeout(context.Background(), g.WSTimeout)
defer cancel()
// Reconnect to the Gateway
if err := g.WS.Dial(ctx); err != nil {
return errors.Wrap(err, "failed to Reconnect")

View file

@ -44,11 +44,10 @@ type Websocket struct {
sendLimiter *rate.Limiter
dialLimiter *rate.Limiter
// Constants. These must not be changed after the Websocket instance is used
// once, as they are not thread-safe.
// Timeout for connecting and writing to the Websocket, uses default
// WSTimeout (global).
// Timeout is the default timeout used if a context with no deadline is
// given to Dial.
//
// It must not be changed after the Websocket is used once.
Timeout time.Duration
}
@ -72,12 +71,14 @@ func NewCustom(conn Connection, addr string) *Websocket {
}
// Dial waits until the rate limiter allows then dials the websocket.
//
// If the passed context has no deadline, Dial will wrap it in a
// context.WithTimeout using ws.Timeout as timeout.
func (ws *Websocket) Dial(ctx context.Context) error {
if ws.Timeout > 0 {
tctx, cancel := context.WithTimeout(ctx, ws.Timeout)
if _, ok := ctx.Deadline(); !ok && ws.Timeout > 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, ws.Timeout)
defer cancel()
ctx = tctx
}
if err := ws.dialLimiter.Wait(ctx); err != nil {