From 794342ab7178fb8083d8b191f5c2a47988c8d678 Mon Sep 17 00:00:00 2001 From: "diamondburned (Forefront)" Date: Mon, 8 Jun 2020 23:02:51 -0700 Subject: [PATCH] Implemented Nickname cancellation using context, the good way --- channel.go | 55 +++++++++++++++++++++++++++++++++++++++++------------- service.go | 10 +++++----- 2 files changed, 47 insertions(+), 18 deletions(-) diff --git a/channel.go b/channel.go index 66fb373..2b271f6 100644 --- a/channel.go +++ b/channel.go @@ -1,6 +1,7 @@ package mock import ( + "context" "math/rand" "strconv" "strings" @@ -40,7 +41,8 @@ type Channel struct { // up to about 12 or so. check sameAuthorLimit. incrAuthor uint8 - busyWg sync.WaitGroup + // single-use write-once context, written on every JoinServer + ctx context.Context } var ( @@ -61,15 +63,25 @@ func (ch *Channel) Name() text.Rich { return text.Rich{Content: ch.name} } +// Nickname sets the labeler to the nickname. It simulates heavy IO. This +// function stops as cancel is called in JoinServer, as Nickname is specially +// for that. func (ch *Channel) Nickname(labeler cchat.LabelContainer) error { - // Simulate IO. - simulateAustralianInternet() + // Borrow the parent's context and stop fetching if the context expires. + ctx, cancel := context.WithCancel(ch.ctx) + defer cancel() + + // Simulate IO with cancellation. Ignore the error if it's a simulated time + // out, else return. + if err := simulateAustralianInternetCtx(ctx); err != nil && err != ErrTimedOut { + return err + } labeler.SetLabel(ch.username) return nil } -func (ch *Channel) JoinServer(container cchat.MessagesContainer) (func(), error) { +func (ch *Channel) JoinServer(container cchat.MessagesContainer) (stop func(), err error) { // Is this a fresh channel? If yes, generate messages with some IO latency. if len(ch.messages) == 0 || ch.messageixs == nil { // Simulate IO. @@ -94,8 +106,8 @@ func (ch *Channel) JoinServer(container cchat.MessagesContainer) (func(), error) } } - // Initialize channels for use. - doneCh := make(chan struct{}) + // Initialize context for cancellation. + ch.ctx, stop = context.WithCancel(context.Background()) go func() { ticker := time.NewTicker(4 * time.Second) @@ -126,13 +138,13 @@ func (ch *Channel) JoinServer(container cchat.MessagesContainer) (func(), error) var old = ch.randomOldMsg() ch.deleteMessage(MessageHeader{old.id, time.Now()}, container) - case <-doneCh: + case <-ch.ctx.Done(): return } } }() - return func() { doneCh <- struct{}{} }, nil + return } func (ch *Channel) RawMessageContent(id string) (string, error) { @@ -272,8 +284,8 @@ func (ch *Channel) nextID() (id uint32) { } func (ch *Channel) SendMessage(msg cchat.SendableMessage) error { - if simulateAustralianInternet() { - return errors.New("Failed to send message: Australian Internet unsupported.") + if err := simulateAustralianInternet(); err != nil { + return errors.Wrap(err, "Failed to send message") } go func() { @@ -360,12 +372,29 @@ func randClamp(min, max int) int { return rand.Intn(max-min) + min } +// ErrTimedOut is returned when the simulated IO decides to fail. +var ErrTimedOut = errors.New("Australian Internet unsupported.") + // simulate network latency -func simulateAustralianInternet() (lost bool) { +func simulateAustralianInternet() error { + return simulateAustralianInternetCtx(context.Background()) +} + +func simulateAustralianInternetCtx(ctx context.Context) (err error) { var ms = randClamp(internetMinLatency, internetMaxLatency) - <-time.After(time.Duration(ms) * time.Millisecond) + + select { + case <-time.After(time.Duration(ms) * time.Millisecond): + // noop + case <-ctx.Done(): + return ctx.Err() + } // because australia, drop packet 20% of the time if internetCanFail is // true. - return internetCanFail && rand.Intn(100) < 20 + if internetCanFail && rand.Intn(100) < 20 { + return ErrTimedOut + } + + return nil } diff --git a/service.go b/service.go index 02703d4..04cdeeb 100644 --- a/service.go +++ b/service.go @@ -3,13 +3,13 @@ package mock import ( "encoding/json" - "errors" "strconv" "time" "github.com/diamondburned/cchat" "github.com/diamondburned/cchat/services" "github.com/diamondburned/cchat/text" + "github.com/pkg/errors" ) func init() { @@ -32,8 +32,8 @@ func (s Service) Name() text.Rich { } func (s Service) RestoreSession(storage map[string]string) (cchat.Session, error) { - if simulateAustralianInternet() { - return nil, errors.New("Restore failed: server machine broke") + if err := simulateAustralianInternet(); err != nil { + return nil, errors.Wrap(err, "Restore failed") } username, ok := storage["username"] @@ -70,8 +70,8 @@ func (Authenticator) AuthenticateForm() []cchat.AuthenticateEntry { func (Authenticator) Authenticate(form []string) (cchat.Session, error) { // SLOW IO TIME. - if simulateAustralianInternet() { - return nil, errors.New("Authentication timed out.") + if err := simulateAustralianInternet(); err != nil { + return nil, errors.Wrap(err, "Authentication failed") } return newSession(form[0]), nil