diff --git a/discord/time.go b/discord/time.go index f965d38..cdd2a0a 100644 --- a/discord/time.go +++ b/discord/time.go @@ -38,6 +38,10 @@ func (t Timestamp) MarshalJSON() ([]byte, error) { return []byte(`"` + time.Time(t).Format(TimestampFormat) + `"`), nil } +func (t Timestamp) Valid() bool { + return !time.Time(t).IsZero() +} + // type UnixTimestamp int64 diff --git a/gateway/events.go b/gateway/events.go index 1825d4c..4a3d798 100644 --- a/gateway/events.go +++ b/gateway/events.go @@ -161,9 +161,9 @@ type ( // Clients may only update their game status 5 times per 20 seconds. PresenceUpdateEvent discord.Presence TypingStartEvent struct { - ChannelID discord.Snowflake `json:"channel_id"` - UserID discord.Snowflake `json:"user_id"` - Timestamp discord.Timestamp `json:"timestamp"` + ChannelID discord.Snowflake `json:"channel_id"` + UserID discord.Snowflake `json:"user_id"` + Timestamp discord.UnixTimestamp `json:"timestamp"` GuildID discord.Snowflake `json:"guild_id,omitempty"` Member *discord.Member `json:"member,omitempty"` diff --git a/internal/wsutil/conn.go b/internal/wsutil/conn.go index 16b514e..4e5562b 100644 --- a/internal/wsutil/conn.go +++ b/internal/wsutil/conn.go @@ -12,7 +12,7 @@ import ( ) var WSBuffer = 12 -var WSReadLimit = 4096 // 4096 bytes +var WSReadLimit int64 = 8192000 // 8 MiB // Connection is an interface that abstracts around a generic Websocket driver. // This connection expects the driver to handle compression by itself. @@ -64,6 +64,8 @@ func (c *Conn) Dial(ctx context.Context, addr string) error { HTTPHeader: headers, }) + c.Conn.SetReadLimit(WSReadLimit) + go func() { c.readLoop(c.events) }() @@ -109,6 +111,7 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) { // Probably a zlib payload z, err := zlib.NewReader(r) if err != nil { + c.CloseRead(ctx) return nil, errors.Wrap(err, "Failed to create a zlib reader") } @@ -117,24 +120,18 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) { r = z } - return ioutil.ReadAll(r) + b, err := ioutil.ReadAll(r) + if err != nil { + c.CloseRead(ctx) + return nil, err + } + + return b, nil } func (c *Conn) Send(ctx context.Context, b []byte) error { // TODO: zlib stream - - w, err := c.Writer(ctx, websocket.MessageText) - if err != nil { - return errors.Wrap(err, "Failed to get WS writer") - } - - defer w.Close() - - // Compress with zlib by default NOT. - // w = zlib.NewWriter(w) - - _, err = w.Write(b) - return err + return c.Write(ctx, websocket.MessageText, b) } func (c *Conn) Close(err error) error { diff --git a/state/state.go b/state/state.go index d4c47c7..1825900 100644 --- a/state/state.go +++ b/state/state.go @@ -4,6 +4,7 @@ package state import ( "log" + "sync" "github.com/diamondburned/arikawa/discord" "github.com/diamondburned/arikawa/gateway" @@ -37,6 +38,11 @@ type State struct { PreHandler *handler.Handler // default nil unhooker func() + + // List of channels with few messages, so it doesn't bother hitting the API + // again. + fewMessages []discord.Snowflake + fewMutex sync.Mutex } func NewFromSession(s *session.Session, store Store) (*State, error) { @@ -298,9 +304,28 @@ func (s *State) Message( // Messages fetches maximum 100 messages from the API, if it has to. There is no // limit if it's from the State storage. func (s *State) Messages(channelID discord.Snowflake) ([]discord.Message, error) { + // TODO: Think of a design that doesn't rely on MaxMessages(). + var maxMsgs = s.MaxMessages() + ms, err := s.Store.Messages(channelID) if err == nil { - return ms, nil + // If the state already has as many messages as it can, skip the API. + if maxMsgs <= len(ms) { + return ms, nil + } + + // Is the channel tiny? + s.fewMutex.Lock() + for _, ch := range s.fewMessages { + if ch == channelID { + // Yes, skip the state. + s.fewMutex.Unlock() + return ms, nil + } + } + + // No, fetch from the state. + s.fewMutex.Unlock() } ms, err = s.Session.Messages(channelID, 100) @@ -314,7 +339,18 @@ func (s *State) Messages(channelID discord.Snowflake) ([]discord.Message, error) } } - return ms, nil + if len(ms) < maxMsgs { + // Tiny channel, store this. + s.fewMutex.Lock() + s.fewMessages = append(s.fewMessages, channelID) + s.fewMutex.Unlock() + + return ms, nil + } + + // Since the latest messages are at the end and we already know the maxMsgs, + // we could slice this right away. + return ms[:maxMsgs], nil } //// diff --git a/state/store.go b/state/store.go index 61bfeca..245ba00 100644 --- a/state/store.go +++ b/state/store.go @@ -34,6 +34,7 @@ type StoreGetter interface { Message(channelID, messageID discord.Snowflake) (*discord.Message, error) Messages(channelID discord.Snowflake) ([]discord.Message, error) + MaxMessages() int // used to know if the state is filled or not. // These don't get fetched from the API, it's Gateway only. Presence(guildID, userID discord.Snowflake) (*discord.Presence, error) diff --git a/state/store_default.go b/state/store_default.go index 5b87c3d..51c9fc4 100644 --- a/state/store_default.go +++ b/state/store_default.go @@ -148,6 +148,11 @@ func (s *DefaultStore) ChannelSet(channel *discord.Channel) error { for i, ch := range chs { if ch.ID == channel.ID { + // Also from discordgo. + if channel.Permissions == nil { + channel.Permissions = ch.Permissions + } + // Found, just edit chs[i] = *channel @@ -289,11 +294,21 @@ func (s *DefaultStore) Guilds() ([]discord.Guild, error) { return gs, nil } -func (s *DefaultStore) GuildSet(g *discord.Guild) error { +func (s *DefaultStore) GuildSet(guild *discord.Guild) error { s.mut.Lock() - s.guilds[g.ID] = g - s.mut.Unlock() + defer s.mut.Unlock() + if g, ok := s.guilds[guild.ID]; ok { + // preserve state stuff + if guild.Roles == nil { + guild.Roles = g.Roles + } + if guild.Emojis == nil { + guild.Emojis = g.Emojis + } + } + + s.guilds[guild.ID] = guild return nil } @@ -425,25 +440,66 @@ func (s *DefaultStore) Messages( return ms, nil } +func (s *DefaultStore) MaxMessages() int { + return int(s.DefaultStoreOptions.MaxMessages) +} + func (s *DefaultStore) MessageSet(message *discord.Message) error { s.mut.Lock() defer s.mut.Unlock() ms, ok := s.messages[message.ChannelID] if !ok { - ms = make([]discord.Message, 0, int(s.MaxMessages)+1) + ms = make([]discord.Message, 0, s.MaxMessages()+1) } - // Append - ms = append(ms, *message) + // Check if we already have the message. + for i, m := range ms { + if m.ID == message.ID { + // Thanks, Discord. + if message.Content != "" { + m.Content = message.Content + } + if message.EditedTimestamp != nil { + m.EditedTimestamp = message.EditedTimestamp + } + if message.Mentions != nil { + m.Mentions = message.Mentions + } + if message.Embeds != nil { + m.Embeds = message.Embeds + } + if message.Attachments != nil { + m.Attachments = message.Attachments + } + if message.Timestamp.Valid() { + m.Timestamp = message.Timestamp + } + if message.Author.ID.Valid() { + m.Author = message.Author + } - // Sort (should be fast since it's presorted) - sort.Slice(ms, func(i, j int) bool { - return ms[i].ID > ms[j].ID - }) + ms[i] = m + return nil + } + } - if len(ms) > int(s.MaxMessages) { - ms = ms[len(ms)-int(s.MaxMessages):] + // Prepend the latest message at the end + + if len(ms) > 0 { + var end = s.MaxMessages() + if len(ms) < end { + end = len(ms) + } + + // Copy hack to prepend. This copies the 0th-(end-1)th entries to + // 1st-endth. + copy(ms[1:end], ms[0:end-1]) + // Then, set the 0th entry. + ms[0] = *message + + } else { + ms = append(ms, *message) } s.messages[message.ChannelID] = ms