1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2024-11-30 10:43:30 +00:00

Fixed bug where message state would screw up

This commit is contained in:
diamondburned (Forefront) 2020-01-20 00:53:23 -08:00
parent 0978d513c4
commit 27e315ca66
6 changed files with 126 additions and 32 deletions

View file

@ -38,6 +38,10 @@ func (t Timestamp) MarshalJSON() ([]byte, error) {
return []byte(`"` + time.Time(t).Format(TimestampFormat) + `"`), nil
}
func (t Timestamp) Valid() bool {
return !time.Time(t).IsZero()
}
//
type UnixTimestamp int64

View file

@ -161,9 +161,9 @@ type (
// Clients may only update their game status 5 times per 20 seconds.
PresenceUpdateEvent discord.Presence
TypingStartEvent struct {
ChannelID discord.Snowflake `json:"channel_id"`
UserID discord.Snowflake `json:"user_id"`
Timestamp discord.Timestamp `json:"timestamp"`
ChannelID discord.Snowflake `json:"channel_id"`
UserID discord.Snowflake `json:"user_id"`
Timestamp discord.UnixTimestamp `json:"timestamp"`
GuildID discord.Snowflake `json:"guild_id,omitempty"`
Member *discord.Member `json:"member,omitempty"`

View file

@ -12,7 +12,7 @@ import (
)
var WSBuffer = 12
var WSReadLimit = 4096 // 4096 bytes
var WSReadLimit int64 = 8192000 // 8 MiB
// Connection is an interface that abstracts around a generic Websocket driver.
// This connection expects the driver to handle compression by itself.
@ -64,6 +64,8 @@ func (c *Conn) Dial(ctx context.Context, addr string) error {
HTTPHeader: headers,
})
c.Conn.SetReadLimit(WSReadLimit)
go func() {
c.readLoop(c.events)
}()
@ -109,6 +111,7 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
// Probably a zlib payload
z, err := zlib.NewReader(r)
if err != nil {
c.CloseRead(ctx)
return nil,
errors.Wrap(err, "Failed to create a zlib reader")
}
@ -117,24 +120,18 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
r = z
}
return ioutil.ReadAll(r)
b, err := ioutil.ReadAll(r)
if err != nil {
c.CloseRead(ctx)
return nil, err
}
return b, nil
}
func (c *Conn) Send(ctx context.Context, b []byte) error {
// TODO: zlib stream
w, err := c.Writer(ctx, websocket.MessageText)
if err != nil {
return errors.Wrap(err, "Failed to get WS writer")
}
defer w.Close()
// Compress with zlib by default NOT.
// w = zlib.NewWriter(w)
_, err = w.Write(b)
return err
return c.Write(ctx, websocket.MessageText, b)
}
func (c *Conn) Close(err error) error {

View file

@ -4,6 +4,7 @@ package state
import (
"log"
"sync"
"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/gateway"
@ -37,6 +38,11 @@ type State struct {
PreHandler *handler.Handler // default nil
unhooker func()
// List of channels with few messages, so it doesn't bother hitting the API
// again.
fewMessages []discord.Snowflake
fewMutex sync.Mutex
}
func NewFromSession(s *session.Session, store Store) (*State, error) {
@ -298,9 +304,28 @@ func (s *State) Message(
// Messages fetches maximum 100 messages from the API, if it has to. There is no
// limit if it's from the State storage.
func (s *State) Messages(channelID discord.Snowflake) ([]discord.Message, error) {
// TODO: Think of a design that doesn't rely on MaxMessages().
var maxMsgs = s.MaxMessages()
ms, err := s.Store.Messages(channelID)
if err == nil {
return ms, nil
// If the state already has as many messages as it can, skip the API.
if maxMsgs <= len(ms) {
return ms, nil
}
// Is the channel tiny?
s.fewMutex.Lock()
for _, ch := range s.fewMessages {
if ch == channelID {
// Yes, skip the state.
s.fewMutex.Unlock()
return ms, nil
}
}
// No, fetch from the state.
s.fewMutex.Unlock()
}
ms, err = s.Session.Messages(channelID, 100)
@ -314,7 +339,18 @@ func (s *State) Messages(channelID discord.Snowflake) ([]discord.Message, error)
}
}
return ms, nil
if len(ms) < maxMsgs {
// Tiny channel, store this.
s.fewMutex.Lock()
s.fewMessages = append(s.fewMessages, channelID)
s.fewMutex.Unlock()
return ms, nil
}
// Since the latest messages are at the end and we already know the maxMsgs,
// we could slice this right away.
return ms[:maxMsgs], nil
}
////

View file

@ -34,6 +34,7 @@ type StoreGetter interface {
Message(channelID, messageID discord.Snowflake) (*discord.Message, error)
Messages(channelID discord.Snowflake) ([]discord.Message, error)
MaxMessages() int // used to know if the state is filled or not.
// These don't get fetched from the API, it's Gateway only.
Presence(guildID, userID discord.Snowflake) (*discord.Presence, error)

View file

@ -148,6 +148,11 @@ func (s *DefaultStore) ChannelSet(channel *discord.Channel) error {
for i, ch := range chs {
if ch.ID == channel.ID {
// Also from discordgo.
if channel.Permissions == nil {
channel.Permissions = ch.Permissions
}
// Found, just edit
chs[i] = *channel
@ -289,11 +294,21 @@ func (s *DefaultStore) Guilds() ([]discord.Guild, error) {
return gs, nil
}
func (s *DefaultStore) GuildSet(g *discord.Guild) error {
func (s *DefaultStore) GuildSet(guild *discord.Guild) error {
s.mut.Lock()
s.guilds[g.ID] = g
s.mut.Unlock()
defer s.mut.Unlock()
if g, ok := s.guilds[guild.ID]; ok {
// preserve state stuff
if guild.Roles == nil {
guild.Roles = g.Roles
}
if guild.Emojis == nil {
guild.Emojis = g.Emojis
}
}
s.guilds[guild.ID] = guild
return nil
}
@ -425,25 +440,66 @@ func (s *DefaultStore) Messages(
return ms, nil
}
func (s *DefaultStore) MaxMessages() int {
return int(s.DefaultStoreOptions.MaxMessages)
}
func (s *DefaultStore) MessageSet(message *discord.Message) error {
s.mut.Lock()
defer s.mut.Unlock()
ms, ok := s.messages[message.ChannelID]
if !ok {
ms = make([]discord.Message, 0, int(s.MaxMessages)+1)
ms = make([]discord.Message, 0, s.MaxMessages()+1)
}
// Append
ms = append(ms, *message)
// Check if we already have the message.
for i, m := range ms {
if m.ID == message.ID {
// Thanks, Discord.
if message.Content != "" {
m.Content = message.Content
}
if message.EditedTimestamp != nil {
m.EditedTimestamp = message.EditedTimestamp
}
if message.Mentions != nil {
m.Mentions = message.Mentions
}
if message.Embeds != nil {
m.Embeds = message.Embeds
}
if message.Attachments != nil {
m.Attachments = message.Attachments
}
if message.Timestamp.Valid() {
m.Timestamp = message.Timestamp
}
if message.Author.ID.Valid() {
m.Author = message.Author
}
// Sort (should be fast since it's presorted)
sort.Slice(ms, func(i, j int) bool {
return ms[i].ID > ms[j].ID
})
ms[i] = m
return nil
}
}
if len(ms) > int(s.MaxMessages) {
ms = ms[len(ms)-int(s.MaxMessages):]
// Prepend the latest message at the end
if len(ms) > 0 {
var end = s.MaxMessages()
if len(ms) < end {
end = len(ms)
}
// Copy hack to prepend. This copies the 0th-(end-1)th entries to
// 1st-endth.
copy(ms[1:end], ms[0:end-1])
// Then, set the 0th entry.
ms[0] = *message
} else {
ms = append(ms, *message)
}
s.messages[message.ChannelID] = ms