mirror of
https://github.com/diamondburned/arikawa.git
synced 2024-11-30 10:43:30 +00:00
State: Fixed erroneous context setting and races in Ready
This commit is contained in:
parent
f6e270ae9c
commit
6cbe95d2b3
|
@ -174,11 +174,7 @@ func (c *Client) SendMessageComplex(
|
|||
return msg, c.RequestJSON(&msg, "POST", URL, httputil.WithJSONBody(data))
|
||||
}
|
||||
|
||||
writer := func(mw *multipart.Writer) error {
|
||||
return data.WriteMultipart(mw)
|
||||
}
|
||||
|
||||
resp, err := c.MeanwhileMultipart(writer, "POST", URL)
|
||||
resp, err := c.MeanwhileMultipart(data.WriteMultipart, "POST", URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -382,7 +382,7 @@ func (ctx *Context) Start() func() {
|
|||
})
|
||||
|
||||
if err != nil {
|
||||
ctx.ErrorLogger(err)
|
||||
ctx.ErrorLogger(errors.Wrap(err, "failed to send message"))
|
||||
|
||||
// TODO: there ought to be a better way lol
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
@ -36,14 +37,22 @@ type Session struct {
|
|||
// Command handler with inherited methods.
|
||||
*handler.Handler
|
||||
|
||||
// MFA only fields
|
||||
MFA bool
|
||||
Ticket string
|
||||
// internal state to not be copied around.
|
||||
*sessionState
|
||||
}
|
||||
|
||||
// sessionState contains fields crucial for controlling the state of session. It
|
||||
// should not be copied around.
|
||||
type sessionState struct {
|
||||
hstop chan struct{}
|
||||
wstop sync.Once
|
||||
}
|
||||
|
||||
func (state *sessionState) Reset() {
|
||||
state.hstop = make(chan struct{})
|
||||
state.wstop = sync.Once{}
|
||||
}
|
||||
|
||||
func NewWithIntents(token string, intents ...gateway.Intents) (*Session, error) {
|
||||
g, err := gateway.NewGatewayWithIntents(token, intents...)
|
||||
if err != nil {
|
||||
|
@ -99,15 +108,15 @@ func NewWithGateway(gw *gateway.Gateway) *Session {
|
|||
return &Session{
|
||||
Gateway: gw,
|
||||
// Nab off gateway's token
|
||||
Client: api.NewClient(gw.Identifier.Token),
|
||||
Handler: handler.New(),
|
||||
Client: api.NewClient(gw.Identifier.Token),
|
||||
Handler: handler.New(),
|
||||
sessionState: &sessionState{},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) Open() error {
|
||||
// Start the handler beforehand so no events are missed.
|
||||
s.hstop = make(chan struct{})
|
||||
s.wstop = sync.Once{}
|
||||
s.sessionState.Reset()
|
||||
go s.startHandler()
|
||||
|
||||
// Set the AfterClose's handler.
|
||||
|
@ -124,6 +133,18 @@ func (s *Session) Open() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// WithContext returns a shallow copy of Session with the context replaced in
|
||||
// the API client. All methods called on the returned Session will use this
|
||||
// given context.
|
||||
//
|
||||
// This method is thread-safe only after Open and before Close are called. Open
|
||||
// and Close should not be called on the returned Session.
|
||||
func (s *Session) WithContext(ctx context.Context) *Session {
|
||||
cpy := *s
|
||||
cpy.Client = s.Client.WithContext(ctx)
|
||||
return &cpy
|
||||
}
|
||||
|
||||
func (s *Session) startHandler() {
|
||||
for {
|
||||
select {
|
||||
|
@ -137,7 +158,7 @@ func (s *Session) startHandler() {
|
|||
|
||||
func (s *Session) Close() error {
|
||||
// Stop the event handler
|
||||
s.wstop.Do(func() { s.hstop <- struct{}{} })
|
||||
s.wstop.Do(func() { close(s.hstop) })
|
||||
// Close the websocket
|
||||
return s.Gateway.Close()
|
||||
}
|
||||
|
|
|
@ -62,8 +62,8 @@ type State struct {
|
|||
|
||||
// *: State doesn't actually keep track of pinned messages.
|
||||
|
||||
// Ready is not updated by the state.
|
||||
Ready gateway.ReadyEvent
|
||||
readyMu sync.Mutex
|
||||
ready gateway.ReadyEvent
|
||||
|
||||
// StateLog logs all errors that come from the state cache. This includes
|
||||
// not found errors. Defaults to a no-op, as state errors aren't that
|
||||
|
@ -142,11 +142,25 @@ func NewFromSession(s *session.Session, store Store) (*State, error) {
|
|||
// method is thread-safe.
|
||||
func (s *State) WithContext(ctx context.Context) *State {
|
||||
copied := *s
|
||||
copied.Client = copied.Client.WithContext(ctx)
|
||||
copied.Session = s.Session.WithContext(ctx)
|
||||
|
||||
return &copied
|
||||
}
|
||||
|
||||
// Ready takes in a callback to access the Ready event in a thread-safe manner.
|
||||
// As it acquires a mutex for thread-safety, the callback shouldn't do anything
|
||||
// blocking to prevent stalling the state updates. It should also not reference
|
||||
// or copy the Ready instance, as that instance will not be thread-safe.
|
||||
//
|
||||
// Note that the Ready that passed in will never be nil; if Ready events are not
|
||||
// received yet, then the pointer will point to State's zero-value Ready
|
||||
// instance.
|
||||
func (s *State) Ready(fn func(*gateway.ReadyEvent)) {
|
||||
s.readyMu.Lock()
|
||||
fn(&s.ready)
|
||||
s.readyMu.Unlock()
|
||||
}
|
||||
|
||||
//// Helper methods
|
||||
|
||||
func (s *State) AuthorDisplayName(message *gateway.MessageCreateEvent) string {
|
||||
|
|
|
@ -50,6 +50,11 @@ func (s *State) hookSession() {
|
|||
func (s *State) onEvent(iface interface{}) {
|
||||
switch ev := iface.(type) {
|
||||
case *gateway.ReadyEvent:
|
||||
// Acquire the ready mutex for the rest of these update calls, as they
|
||||
// will be accessing ready's fields.
|
||||
s.readyMu.Lock()
|
||||
s.ready = *ev
|
||||
|
||||
// Reset the store before proceeding.
|
||||
if resetter, ok := s.Store.(StoreResetter); ok {
|
||||
if err := resetter.Reset(); err != nil {
|
||||
|
@ -57,9 +62,6 @@ func (s *State) onEvent(iface interface{}) {
|
|||
}
|
||||
}
|
||||
|
||||
// Set Ready to the state
|
||||
s.Ready = *ev
|
||||
|
||||
// Handle presences
|
||||
for _, p := range ev.Presences {
|
||||
if err := s.Store.PresenceSet(0, p); err != nil {
|
||||
|
@ -84,6 +86,9 @@ func (s *State) onEvent(iface interface{}) {
|
|||
s.stateErr(err, "failed to set self in state")
|
||||
}
|
||||
|
||||
// Release the ready mutex only after we're done with everything.
|
||||
s.readyMu.Unlock()
|
||||
|
||||
case *gateway.GuildCreateEvent:
|
||||
s.batchLog(storeGuildCreate(s.Store, ev))
|
||||
|
||||
|
@ -268,17 +273,23 @@ func (s *State) onEvent(iface interface{}) {
|
|||
case *gateway.SessionsReplaceEvent:
|
||||
|
||||
case *gateway.UserGuildSettingsUpdateEvent:
|
||||
for i, ugs := range s.Ready.UserGuildSettings {
|
||||
s.readyMu.Lock()
|
||||
for i, ugs := range s.ready.UserGuildSettings {
|
||||
if ugs.GuildID == ev.GuildID {
|
||||
s.Ready.UserGuildSettings[i] = ev.UserGuildSettings
|
||||
s.ready.UserGuildSettings[i] = ev.UserGuildSettings
|
||||
}
|
||||
}
|
||||
s.readyMu.Unlock()
|
||||
|
||||
case *gateway.UserSettingsUpdateEvent:
|
||||
s.Ready.Settings = &ev.UserSettings
|
||||
s.readyMu.Lock()
|
||||
s.ready.Settings = &ev.UserSettings
|
||||
s.readyMu.Unlock()
|
||||
|
||||
case *gateway.UserNoteUpdateEvent:
|
||||
s.Ready.Notes[ev.ID] = ev.Note
|
||||
s.readyMu.Lock()
|
||||
s.ready.Notes[ev.ID] = ev.Note
|
||||
s.readyMu.Unlock()
|
||||
|
||||
case *gateway.UserUpdateEvent:
|
||||
if err := s.Store.MyselfSet(ev.User); err != nil {
|
||||
|
|
|
@ -6,7 +6,9 @@ import (
|
|||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"log"
|
||||
"mime/multipart"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
|
@ -56,6 +58,7 @@ func (c *Client) Copy() *Client {
|
|||
|
||||
// WithContext returns a client copy of the client with the given context.
|
||||
func (c *Client) WithContext(ctx context.Context) *Client {
|
||||
log.Println("Setting request; stack trace:", string(debug.Stack()))
|
||||
c = c.Copy()
|
||||
c.context = ctx
|
||||
return c
|
||||
|
@ -89,24 +92,10 @@ func (c *Client) MeanwhileMultipart(
|
|||
writer func(*multipart.Writer) error,
|
||||
method, url string, opts ...RequestOption) (httpdriver.Response, error) {
|
||||
|
||||
// We want to cancel the request if our bodyWriter fails.
|
||||
ctx, cancel := context.WithCancel(c.context)
|
||||
defer cancel()
|
||||
|
||||
r, w := io.Pipe()
|
||||
body := multipart.NewWriter(w)
|
||||
|
||||
var bgErr error
|
||||
|
||||
go func() {
|
||||
if err := writer(body); err != nil {
|
||||
bgErr = err
|
||||
cancel()
|
||||
}
|
||||
|
||||
// Close the writer so the body gets flushed to the HTTP reader.
|
||||
w.Close()
|
||||
}()
|
||||
go func() { w.CloseWithError(writer(body)) }()
|
||||
|
||||
// Prepend the multipart writer and the correct Content-Type header options.
|
||||
opts = PrependOptions(
|
||||
|
@ -116,11 +105,7 @@ func (c *Client) MeanwhileMultipart(
|
|||
)
|
||||
|
||||
// Request with the current client and our own context:
|
||||
resp, err := c.WithContext(ctx).Request(method, url, opts...)
|
||||
if err != nil && bgErr != nil {
|
||||
return nil, bgErr
|
||||
}
|
||||
return resp, err
|
||||
return c.Request(method, url, opts...)
|
||||
}
|
||||
|
||||
func (c *Client) FastRequest(method, url string, opts ...RequestOption) error {
|
||||
|
@ -176,7 +161,7 @@ func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.
|
|||
fn(q, nil)
|
||||
}
|
||||
// Exit after cleaning everything up.
|
||||
return nil, errors.Wrap(err, "failed to apply options")
|
||||
return nil, errors.Wrap(err, "failed to apply http request options")
|
||||
}
|
||||
|
||||
r, doErr = c.Client.Do(q)
|
||||
|
|
Loading…
Reference in a new issue