1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2024-11-30 10:43:30 +00:00

State: Fixed erroneous context setting and races in Ready

This commit is contained in:
diamondburned 2020-11-14 15:30:18 -08:00
parent f6e270ae9c
commit 6cbe95d2b3
6 changed files with 72 additions and 45 deletions

View file

@ -174,11 +174,7 @@ func (c *Client) SendMessageComplex(
return msg, c.RequestJSON(&msg, "POST", URL, httputil.WithJSONBody(data))
}
writer := func(mw *multipart.Writer) error {
return data.WriteMultipart(mw)
}
resp, err := c.MeanwhileMultipart(writer, "POST", URL)
resp, err := c.MeanwhileMultipart(data.WriteMultipart, "POST", URL)
if err != nil {
return nil, err
}

View file

@ -382,7 +382,7 @@ func (ctx *Context) Start() func() {
})
if err != nil {
ctx.ErrorLogger(err)
ctx.ErrorLogger(errors.Wrap(err, "failed to send message"))
// TODO: there ought to be a better way lol
}

View file

@ -4,6 +4,7 @@
package session
import (
"context"
"sync"
"github.com/pkg/errors"
@ -36,14 +37,22 @@ type Session struct {
// Command handler with inherited methods.
*handler.Handler
// MFA only fields
MFA bool
Ticket string
// internal state to not be copied around.
*sessionState
}
// sessionState contains fields crucial for controlling the state of session. It
// should not be copied around.
type sessionState struct {
hstop chan struct{}
wstop sync.Once
}
func (state *sessionState) Reset() {
state.hstop = make(chan struct{})
state.wstop = sync.Once{}
}
func NewWithIntents(token string, intents ...gateway.Intents) (*Session, error) {
g, err := gateway.NewGatewayWithIntents(token, intents...)
if err != nil {
@ -99,15 +108,15 @@ func NewWithGateway(gw *gateway.Gateway) *Session {
return &Session{
Gateway: gw,
// Nab off gateway's token
Client: api.NewClient(gw.Identifier.Token),
Handler: handler.New(),
Client: api.NewClient(gw.Identifier.Token),
Handler: handler.New(),
sessionState: &sessionState{},
}
}
func (s *Session) Open() error {
// Start the handler beforehand so no events are missed.
s.hstop = make(chan struct{})
s.wstop = sync.Once{}
s.sessionState.Reset()
go s.startHandler()
// Set the AfterClose's handler.
@ -124,6 +133,18 @@ func (s *Session) Open() error {
return nil
}
// WithContext returns a shallow copy of Session with the context replaced in
// the API client. All methods called on the returned Session will use this
// given context.
//
// This method is thread-safe only after Open and before Close are called. Open
// and Close should not be called on the returned Session.
func (s *Session) WithContext(ctx context.Context) *Session {
cpy := *s
cpy.Client = s.Client.WithContext(ctx)
return &cpy
}
func (s *Session) startHandler() {
for {
select {
@ -137,7 +158,7 @@ func (s *Session) startHandler() {
func (s *Session) Close() error {
// Stop the event handler
s.wstop.Do(func() { s.hstop <- struct{}{} })
s.wstop.Do(func() { close(s.hstop) })
// Close the websocket
return s.Gateway.Close()
}

View file

@ -62,8 +62,8 @@ type State struct {
// *: State doesn't actually keep track of pinned messages.
// Ready is not updated by the state.
Ready gateway.ReadyEvent
readyMu sync.Mutex
ready gateway.ReadyEvent
// StateLog logs all errors that come from the state cache. This includes
// not found errors. Defaults to a no-op, as state errors aren't that
@ -142,11 +142,25 @@ func NewFromSession(s *session.Session, store Store) (*State, error) {
// method is thread-safe.
func (s *State) WithContext(ctx context.Context) *State {
copied := *s
copied.Client = copied.Client.WithContext(ctx)
copied.Session = s.Session.WithContext(ctx)
return &copied
}
// Ready takes in a callback to access the Ready event in a thread-safe manner.
// As it acquires a mutex for thread-safety, the callback shouldn't do anything
// blocking to prevent stalling the state updates. It should also not reference
// or copy the Ready instance, as that instance will not be thread-safe.
//
// Note that the Ready that passed in will never be nil; if Ready events are not
// received yet, then the pointer will point to State's zero-value Ready
// instance.
func (s *State) Ready(fn func(*gateway.ReadyEvent)) {
s.readyMu.Lock()
fn(&s.ready)
s.readyMu.Unlock()
}
//// Helper methods
func (s *State) AuthorDisplayName(message *gateway.MessageCreateEvent) string {

View file

@ -50,6 +50,11 @@ func (s *State) hookSession() {
func (s *State) onEvent(iface interface{}) {
switch ev := iface.(type) {
case *gateway.ReadyEvent:
// Acquire the ready mutex for the rest of these update calls, as they
// will be accessing ready's fields.
s.readyMu.Lock()
s.ready = *ev
// Reset the store before proceeding.
if resetter, ok := s.Store.(StoreResetter); ok {
if err := resetter.Reset(); err != nil {
@ -57,9 +62,6 @@ func (s *State) onEvent(iface interface{}) {
}
}
// Set Ready to the state
s.Ready = *ev
// Handle presences
for _, p := range ev.Presences {
if err := s.Store.PresenceSet(0, p); err != nil {
@ -84,6 +86,9 @@ func (s *State) onEvent(iface interface{}) {
s.stateErr(err, "failed to set self in state")
}
// Release the ready mutex only after we're done with everything.
s.readyMu.Unlock()
case *gateway.GuildCreateEvent:
s.batchLog(storeGuildCreate(s.Store, ev))
@ -268,17 +273,23 @@ func (s *State) onEvent(iface interface{}) {
case *gateway.SessionsReplaceEvent:
case *gateway.UserGuildSettingsUpdateEvent:
for i, ugs := range s.Ready.UserGuildSettings {
s.readyMu.Lock()
for i, ugs := range s.ready.UserGuildSettings {
if ugs.GuildID == ev.GuildID {
s.Ready.UserGuildSettings[i] = ev.UserGuildSettings
s.ready.UserGuildSettings[i] = ev.UserGuildSettings
}
}
s.readyMu.Unlock()
case *gateway.UserSettingsUpdateEvent:
s.Ready.Settings = &ev.UserSettings
s.readyMu.Lock()
s.ready.Settings = &ev.UserSettings
s.readyMu.Unlock()
case *gateway.UserNoteUpdateEvent:
s.Ready.Notes[ev.ID] = ev.Note
s.readyMu.Lock()
s.ready.Notes[ev.ID] = ev.Note
s.readyMu.Unlock()
case *gateway.UserUpdateEvent:
if err := s.Store.MyselfSet(ev.User); err != nil {

View file

@ -6,7 +6,9 @@ import (
"bytes"
"context"
"io"
"log"
"mime/multipart"
"runtime/debug"
"github.com/pkg/errors"
@ -56,6 +58,7 @@ func (c *Client) Copy() *Client {
// WithContext returns a client copy of the client with the given context.
func (c *Client) WithContext(ctx context.Context) *Client {
log.Println("Setting request; stack trace:", string(debug.Stack()))
c = c.Copy()
c.context = ctx
return c
@ -89,24 +92,10 @@ func (c *Client) MeanwhileMultipart(
writer func(*multipart.Writer) error,
method, url string, opts ...RequestOption) (httpdriver.Response, error) {
// We want to cancel the request if our bodyWriter fails.
ctx, cancel := context.WithCancel(c.context)
defer cancel()
r, w := io.Pipe()
body := multipart.NewWriter(w)
var bgErr error
go func() {
if err := writer(body); err != nil {
bgErr = err
cancel()
}
// Close the writer so the body gets flushed to the HTTP reader.
w.Close()
}()
go func() { w.CloseWithError(writer(body)) }()
// Prepend the multipart writer and the correct Content-Type header options.
opts = PrependOptions(
@ -116,11 +105,7 @@ func (c *Client) MeanwhileMultipart(
)
// Request with the current client and our own context:
resp, err := c.WithContext(ctx).Request(method, url, opts...)
if err != nil && bgErr != nil {
return nil, bgErr
}
return resp, err
return c.Request(method, url, opts...)
}
func (c *Client) FastRequest(method, url string, opts ...RequestOption) error {
@ -176,7 +161,7 @@ func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver.
fn(q, nil)
}
// Exit after cleaning everything up.
return nil, errors.Wrap(err, "failed to apply options")
return nil, errors.Wrap(err, "failed to apply http request options")
}
r, doErr = c.Client.Do(q)