From 6dafb30401a6cb8fc4a44be51a6b0235c3673473 Mon Sep 17 00:00:00 2001 From: "diamondburned (Forefront)" Date: Tue, 7 Apr 2020 19:33:56 -0700 Subject: [PATCH] State: Added more wrappers for direct messaging channels --- gateway/pacemaker.go | 27 +++++++++++---------------- state/state.go | 36 ++++++++++++++++++++++++++++++++++++ state/store.go | 3 +++ state/store_default.go | 21 +++++++++++++++++++-- state/store_noop.go | 4 ++++ 5 files changed, 73 insertions(+), 18 deletions(-) diff --git a/gateway/pacemaker.go b/gateway/pacemaker.go index ca69be4..593bac0 100644 --- a/gateway/pacemaker.go +++ b/gateway/pacemaker.go @@ -26,7 +26,7 @@ type Pacemaker struct { // Event OnDead func() error - stop chan<- struct{} + stop chan struct{} death chan error } @@ -60,19 +60,14 @@ func (p *Pacemaker) Dead() bool { func (p *Pacemaker) Stop() { if p.stop != nil { - close(p.stop) - p.stop = nil + p.stop <- struct{}{} } } -func (p *Pacemaker) start(stop chan struct{}, wg *sync.WaitGroup) error { +func (p *Pacemaker) start() error { tick := time.NewTicker(p.Heartrate) defer tick.Stop() - if wg != nil { - defer wg.Done() - } - // Echo at least once p.Echo() @@ -89,7 +84,7 @@ func (p *Pacemaker) start(stop chan struct{}, wg *sync.WaitGroup) error { } select { - case <-stop: + case <-p.stop: return nil case <-tick.C: @@ -100,16 +95,16 @@ func (p *Pacemaker) start(stop chan struct{}, wg *sync.WaitGroup) error { // StartAsync starts the pacemaker asynchronously. The WaitGroup is optional. func (p *Pacemaker) StartAsync(wg *sync.WaitGroup) (death chan error) { p.death = make(chan error) + p.stop = make(chan struct{}) - stop := make(chan struct{}) - p.stop = stop - - if wg != nil { - wg.Add(1) - } + wg.Add(1) go func() { - p.death <- p.start(stop, wg) + p.death <- p.start() + // Mark the pacemaker loop as done. + wg.Done() + // Mark the stop channel as nil, so later Close() calls won't block forever. + p.stop = nil }() return p.death diff --git a/state/state.go b/state/state.go index 1594c8d..c925a8d 100644 --- a/state/state.go +++ b/state/state.go @@ -223,6 +223,42 @@ func (s *State) Channels(guildID discord.Snowflake) ([]discord.Channel, error) { return c, nil } +func (s *State) CreatePrivateChannel(recipient discord.Snowflake) (*discord.Channel, error) { + c, err := s.Store.CreatePrivateChannel(recipient) + if err == nil { + return c, nil + } + + c, err = s.Session.CreatePrivateChannel(recipient) + if err != nil { + return nil, err + } + + return c, s.Store.ChannelSet(c) +} + +func (s *State) PrivateChannels() ([]discord.Channel, error) { + c, err := s.Store.PrivateChannels() + if err == nil { + return c, nil + } + + c, err = s.Session.PrivateChannels() + if err != nil { + return nil, err + } + + for _, ch := range c { + ch := ch + + if err := s.Store.ChannelSet(&ch); err != nil { + return nil, err + } + } + + return c, nil +} + //// func (s *State) Emoji( diff --git a/state/store.go b/state/store.go index 4979e8b..6aebbbf 100644 --- a/state/store.go +++ b/state/store.go @@ -27,6 +27,9 @@ type StoreGetter interface { // Channel should check for both DM and guild channels. Channel(id discord.Snowflake) (*discord.Channel, error) Channels(guildID discord.Snowflake) ([]discord.Channel, error) + + // same API as (*api.Client) + CreatePrivateChannel(recipient discord.Snowflake) (*discord.Channel, error) PrivateChannels() ([]discord.Channel, error) Emoji(guildID, emojiID discord.Snowflake) (*discord.Emoji, error) diff --git a/state/store_default.go b/state/store_default.go index 2a81f8c..74784be 100644 --- a/state/store_default.go +++ b/state/store_default.go @@ -118,17 +118,34 @@ func (s *DefaultStore) Channels(guildID discord.Snowflake) ([]discord.Channel, e return append([]discord.Channel{}, chs...), nil } +// CreatePrivateChannel searches in the cache for a private channel. It makes no +// API calls. +func (s *DefaultStore) CreatePrivateChannel(recipient discord.Snowflake) (*discord.Channel, error) { + s.mut.Lock() + defer s.mut.Unlock() + + // slow way + for _, ch := range s.privates { + if ch.Type != discord.DirectMessage || len(ch.DMRecipients) < 1 { + continue + } + if ch.DMRecipients[0].ID == recipient { + return &(*ch), nil + } + } + return nil, ErrStoreNotFound +} + // PrivateChannels returns a list of Direct Message channels randomly ordered. func (s *DefaultStore) PrivateChannels() ([]discord.Channel, error) { s.mut.Lock() + defer s.mut.Unlock() var chs = make([]discord.Channel, 0, len(s.privates)) for _, ch := range s.privates { chs = append(chs, *ch) } - s.mut.Unlock() - return chs, nil } diff --git a/state/store_noop.go b/state/store_noop.go index ae0c255..94d9144 100644 --- a/state/store_noop.go +++ b/state/store_noop.go @@ -35,6 +35,10 @@ func (NoopStore) Channels(discord.Snowflake) ([]discord.Channel, error) { return nil, ErrNotImplemented } +func (NoopStore) CreatePrivateChannel(discord.Snowflake) (*discord.Channel, error) { + return nil, ErrNotImplemented +} + func (NoopStore) PrivateChannels() ([]discord.Channel, error) { return nil, ErrNotImplemented }