diff --git a/api/send.go b/api/send.go index c6ecefd..1b2f986 100644 --- a/api/send.go +++ b/api/send.go @@ -174,11 +174,7 @@ func (c *Client) SendMessageComplex( return msg, c.RequestJSON(&msg, "POST", URL, httputil.WithJSONBody(data)) } - writer := func(mw *multipart.Writer) error { - return data.WriteMultipart(mw) - } - - resp, err := c.MeanwhileMultipart(writer, "POST", URL) + resp, err := c.MeanwhileMultipart(data.WriteMultipart, "POST", URL) if err != nil { return nil, err } diff --git a/bot/ctx.go b/bot/ctx.go index 1e5d2dc..c278820 100644 --- a/bot/ctx.go +++ b/bot/ctx.go @@ -382,7 +382,7 @@ func (ctx *Context) Start() func() { }) if err != nil { - ctx.ErrorLogger(err) + ctx.ErrorLogger(errors.Wrap(err, "failed to send message")) // TODO: there ought to be a better way lol } diff --git a/session/session.go b/session/session.go index 484f394..9bc890a 100644 --- a/session/session.go +++ b/session/session.go @@ -4,6 +4,7 @@ package session import ( + "context" "sync" "github.com/pkg/errors" @@ -36,14 +37,22 @@ type Session struct { // Command handler with inherited methods. *handler.Handler - // MFA only fields - MFA bool - Ticket string + // internal state to not be copied around. + *sessionState +} +// sessionState contains fields crucial for controlling the state of session. It +// should not be copied around. +type sessionState struct { hstop chan struct{} wstop sync.Once } +func (state *sessionState) Reset() { + state.hstop = make(chan struct{}) + state.wstop = sync.Once{} +} + func NewWithIntents(token string, intents ...gateway.Intents) (*Session, error) { g, err := gateway.NewGatewayWithIntents(token, intents...) if err != nil { @@ -99,15 +108,15 @@ func NewWithGateway(gw *gateway.Gateway) *Session { return &Session{ Gateway: gw, // Nab off gateway's token - Client: api.NewClient(gw.Identifier.Token), - Handler: handler.New(), + Client: api.NewClient(gw.Identifier.Token), + Handler: handler.New(), + sessionState: &sessionState{}, } } func (s *Session) Open() error { // Start the handler beforehand so no events are missed. - s.hstop = make(chan struct{}) - s.wstop = sync.Once{} + s.sessionState.Reset() go s.startHandler() // Set the AfterClose's handler. @@ -124,6 +133,18 @@ func (s *Session) Open() error { return nil } +// WithContext returns a shallow copy of Session with the context replaced in +// the API client. All methods called on the returned Session will use this +// given context. +// +// This method is thread-safe only after Open and before Close are called. Open +// and Close should not be called on the returned Session. +func (s *Session) WithContext(ctx context.Context) *Session { + cpy := *s + cpy.Client = s.Client.WithContext(ctx) + return &cpy +} + func (s *Session) startHandler() { for { select { @@ -137,7 +158,7 @@ func (s *Session) startHandler() { func (s *Session) Close() error { // Stop the event handler - s.wstop.Do(func() { s.hstop <- struct{}{} }) + s.wstop.Do(func() { close(s.hstop) }) // Close the websocket return s.Gateway.Close() } diff --git a/state/state.go b/state/state.go index 234fe3b..c7e9d9b 100644 --- a/state/state.go +++ b/state/state.go @@ -62,8 +62,8 @@ type State struct { // *: State doesn't actually keep track of pinned messages. - // Ready is not updated by the state. - Ready gateway.ReadyEvent + readyMu sync.Mutex + ready gateway.ReadyEvent // StateLog logs all errors that come from the state cache. This includes // not found errors. Defaults to a no-op, as state errors aren't that @@ -142,11 +142,25 @@ func NewFromSession(s *session.Session, store Store) (*State, error) { // method is thread-safe. func (s *State) WithContext(ctx context.Context) *State { copied := *s - copied.Client = copied.Client.WithContext(ctx) + copied.Session = s.Session.WithContext(ctx) return &copied } +// Ready takes in a callback to access the Ready event in a thread-safe manner. +// As it acquires a mutex for thread-safety, the callback shouldn't do anything +// blocking to prevent stalling the state updates. It should also not reference +// or copy the Ready instance, as that instance will not be thread-safe. +// +// Note that the Ready that passed in will never be nil; if Ready events are not +// received yet, then the pointer will point to State's zero-value Ready +// instance. +func (s *State) Ready(fn func(*gateway.ReadyEvent)) { + s.readyMu.Lock() + fn(&s.ready) + s.readyMu.Unlock() +} + //// Helper methods func (s *State) AuthorDisplayName(message *gateway.MessageCreateEvent) string { diff --git a/state/state_events.go b/state/state_events.go index cb389fb..2e4afcc 100644 --- a/state/state_events.go +++ b/state/state_events.go @@ -50,6 +50,11 @@ func (s *State) hookSession() { func (s *State) onEvent(iface interface{}) { switch ev := iface.(type) { case *gateway.ReadyEvent: + // Acquire the ready mutex for the rest of these update calls, as they + // will be accessing ready's fields. + s.readyMu.Lock() + s.ready = *ev + // Reset the store before proceeding. if resetter, ok := s.Store.(StoreResetter); ok { if err := resetter.Reset(); err != nil { @@ -57,9 +62,6 @@ func (s *State) onEvent(iface interface{}) { } } - // Set Ready to the state - s.Ready = *ev - // Handle presences for _, p := range ev.Presences { if err := s.Store.PresenceSet(0, p); err != nil { @@ -84,6 +86,9 @@ func (s *State) onEvent(iface interface{}) { s.stateErr(err, "failed to set self in state") } + // Release the ready mutex only after we're done with everything. + s.readyMu.Unlock() + case *gateway.GuildCreateEvent: s.batchLog(storeGuildCreate(s.Store, ev)) @@ -268,17 +273,23 @@ func (s *State) onEvent(iface interface{}) { case *gateway.SessionsReplaceEvent: case *gateway.UserGuildSettingsUpdateEvent: - for i, ugs := range s.Ready.UserGuildSettings { + s.readyMu.Lock() + for i, ugs := range s.ready.UserGuildSettings { if ugs.GuildID == ev.GuildID { - s.Ready.UserGuildSettings[i] = ev.UserGuildSettings + s.ready.UserGuildSettings[i] = ev.UserGuildSettings } } + s.readyMu.Unlock() case *gateway.UserSettingsUpdateEvent: - s.Ready.Settings = &ev.UserSettings + s.readyMu.Lock() + s.ready.Settings = &ev.UserSettings + s.readyMu.Unlock() case *gateway.UserNoteUpdateEvent: - s.Ready.Notes[ev.ID] = ev.Note + s.readyMu.Lock() + s.ready.Notes[ev.ID] = ev.Note + s.readyMu.Unlock() case *gateway.UserUpdateEvent: if err := s.Store.MyselfSet(ev.User); err != nil { diff --git a/utils/httputil/client.go b/utils/httputil/client.go index ea95892..a1d2dbb 100644 --- a/utils/httputil/client.go +++ b/utils/httputil/client.go @@ -6,7 +6,9 @@ import ( "bytes" "context" "io" + "log" "mime/multipart" + "runtime/debug" "github.com/pkg/errors" @@ -56,6 +58,7 @@ func (c *Client) Copy() *Client { // WithContext returns a client copy of the client with the given context. func (c *Client) WithContext(ctx context.Context) *Client { + log.Println("Setting request; stack trace:", string(debug.Stack())) c = c.Copy() c.context = ctx return c @@ -89,24 +92,10 @@ func (c *Client) MeanwhileMultipart( writer func(*multipart.Writer) error, method, url string, opts ...RequestOption) (httpdriver.Response, error) { - // We want to cancel the request if our bodyWriter fails. - ctx, cancel := context.WithCancel(c.context) - defer cancel() - r, w := io.Pipe() body := multipart.NewWriter(w) - var bgErr error - - go func() { - if err := writer(body); err != nil { - bgErr = err - cancel() - } - - // Close the writer so the body gets flushed to the HTTP reader. - w.Close() - }() + go func() { w.CloseWithError(writer(body)) }() // Prepend the multipart writer and the correct Content-Type header options. opts = PrependOptions( @@ -116,11 +105,7 @@ func (c *Client) MeanwhileMultipart( ) // Request with the current client and our own context: - resp, err := c.WithContext(ctx).Request(method, url, opts...) - if err != nil && bgErr != nil { - return nil, bgErr - } - return resp, err + return c.Request(method, url, opts...) } func (c *Client) FastRequest(method, url string, opts ...RequestOption) error { @@ -176,7 +161,7 @@ func (c *Client) Request(method, url string, opts ...RequestOption) (httpdriver. fn(q, nil) } // Exit after cleaning everything up. - return nil, errors.Wrap(err, "failed to apply options") + return nil, errors.Wrap(err, "failed to apply http request options") } r, doErr = c.Client.Do(q)