mirror of
https://github.com/diamondburned/arikawa.git
synced 2024-11-30 10:43:30 +00:00
Fixed bug where message state would screw up
This commit is contained in:
parent
0978d513c4
commit
27e315ca66
|
@ -38,6 +38,10 @@ func (t Timestamp) MarshalJSON() ([]byte, error) {
|
|||
return []byte(`"` + time.Time(t).Format(TimestampFormat) + `"`), nil
|
||||
}
|
||||
|
||||
func (t Timestamp) Valid() bool {
|
||||
return !time.Time(t).IsZero()
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
type UnixTimestamp int64
|
||||
|
|
|
@ -161,9 +161,9 @@ type (
|
|||
// Clients may only update their game status 5 times per 20 seconds.
|
||||
PresenceUpdateEvent discord.Presence
|
||||
TypingStartEvent struct {
|
||||
ChannelID discord.Snowflake `json:"channel_id"`
|
||||
UserID discord.Snowflake `json:"user_id"`
|
||||
Timestamp discord.Timestamp `json:"timestamp"`
|
||||
ChannelID discord.Snowflake `json:"channel_id"`
|
||||
UserID discord.Snowflake `json:"user_id"`
|
||||
Timestamp discord.UnixTimestamp `json:"timestamp"`
|
||||
|
||||
GuildID discord.Snowflake `json:"guild_id,omitempty"`
|
||||
Member *discord.Member `json:"member,omitempty"`
|
||||
|
|
|
@ -12,7 +12,7 @@ import (
|
|||
)
|
||||
|
||||
var WSBuffer = 12
|
||||
var WSReadLimit = 4096 // 4096 bytes
|
||||
var WSReadLimit int64 = 8192000 // 8 MiB
|
||||
|
||||
// Connection is an interface that abstracts around a generic Websocket driver.
|
||||
// This connection expects the driver to handle compression by itself.
|
||||
|
@ -64,6 +64,8 @@ func (c *Conn) Dial(ctx context.Context, addr string) error {
|
|||
HTTPHeader: headers,
|
||||
})
|
||||
|
||||
c.Conn.SetReadLimit(WSReadLimit)
|
||||
|
||||
go func() {
|
||||
c.readLoop(c.events)
|
||||
}()
|
||||
|
@ -109,6 +111,7 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
|
|||
// Probably a zlib payload
|
||||
z, err := zlib.NewReader(r)
|
||||
if err != nil {
|
||||
c.CloseRead(ctx)
|
||||
return nil,
|
||||
errors.Wrap(err, "Failed to create a zlib reader")
|
||||
}
|
||||
|
@ -117,24 +120,18 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
|
|||
r = z
|
||||
}
|
||||
|
||||
return ioutil.ReadAll(r)
|
||||
b, err := ioutil.ReadAll(r)
|
||||
if err != nil {
|
||||
c.CloseRead(ctx)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (c *Conn) Send(ctx context.Context, b []byte) error {
|
||||
// TODO: zlib stream
|
||||
|
||||
w, err := c.Writer(ctx, websocket.MessageText)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to get WS writer")
|
||||
}
|
||||
|
||||
defer w.Close()
|
||||
|
||||
// Compress with zlib by default NOT.
|
||||
// w = zlib.NewWriter(w)
|
||||
|
||||
_, err = w.Write(b)
|
||||
return err
|
||||
return c.Write(ctx, websocket.MessageText, b)
|
||||
}
|
||||
|
||||
func (c *Conn) Close(err error) error {
|
||||
|
|
|
@ -4,6 +4,7 @@ package state
|
|||
|
||||
import (
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"github.com/diamondburned/arikawa/discord"
|
||||
"github.com/diamondburned/arikawa/gateway"
|
||||
|
@ -37,6 +38,11 @@ type State struct {
|
|||
PreHandler *handler.Handler // default nil
|
||||
|
||||
unhooker func()
|
||||
|
||||
// List of channels with few messages, so it doesn't bother hitting the API
|
||||
// again.
|
||||
fewMessages []discord.Snowflake
|
||||
fewMutex sync.Mutex
|
||||
}
|
||||
|
||||
func NewFromSession(s *session.Session, store Store) (*State, error) {
|
||||
|
@ -298,9 +304,28 @@ func (s *State) Message(
|
|||
// Messages fetches maximum 100 messages from the API, if it has to. There is no
|
||||
// limit if it's from the State storage.
|
||||
func (s *State) Messages(channelID discord.Snowflake) ([]discord.Message, error) {
|
||||
// TODO: Think of a design that doesn't rely on MaxMessages().
|
||||
var maxMsgs = s.MaxMessages()
|
||||
|
||||
ms, err := s.Store.Messages(channelID)
|
||||
if err == nil {
|
||||
return ms, nil
|
||||
// If the state already has as many messages as it can, skip the API.
|
||||
if maxMsgs <= len(ms) {
|
||||
return ms, nil
|
||||
}
|
||||
|
||||
// Is the channel tiny?
|
||||
s.fewMutex.Lock()
|
||||
for _, ch := range s.fewMessages {
|
||||
if ch == channelID {
|
||||
// Yes, skip the state.
|
||||
s.fewMutex.Unlock()
|
||||
return ms, nil
|
||||
}
|
||||
}
|
||||
|
||||
// No, fetch from the state.
|
||||
s.fewMutex.Unlock()
|
||||
}
|
||||
|
||||
ms, err = s.Session.Messages(channelID, 100)
|
||||
|
@ -314,7 +339,18 @@ func (s *State) Messages(channelID discord.Snowflake) ([]discord.Message, error)
|
|||
}
|
||||
}
|
||||
|
||||
return ms, nil
|
||||
if len(ms) < maxMsgs {
|
||||
// Tiny channel, store this.
|
||||
s.fewMutex.Lock()
|
||||
s.fewMessages = append(s.fewMessages, channelID)
|
||||
s.fewMutex.Unlock()
|
||||
|
||||
return ms, nil
|
||||
}
|
||||
|
||||
// Since the latest messages are at the end and we already know the maxMsgs,
|
||||
// we could slice this right away.
|
||||
return ms[:maxMsgs], nil
|
||||
}
|
||||
|
||||
////
|
||||
|
|
|
@ -34,6 +34,7 @@ type StoreGetter interface {
|
|||
|
||||
Message(channelID, messageID discord.Snowflake) (*discord.Message, error)
|
||||
Messages(channelID discord.Snowflake) ([]discord.Message, error)
|
||||
MaxMessages() int // used to know if the state is filled or not.
|
||||
|
||||
// These don't get fetched from the API, it's Gateway only.
|
||||
Presence(guildID, userID discord.Snowflake) (*discord.Presence, error)
|
||||
|
|
|
@ -148,6 +148,11 @@ func (s *DefaultStore) ChannelSet(channel *discord.Channel) error {
|
|||
|
||||
for i, ch := range chs {
|
||||
if ch.ID == channel.ID {
|
||||
// Also from discordgo.
|
||||
if channel.Permissions == nil {
|
||||
channel.Permissions = ch.Permissions
|
||||
}
|
||||
|
||||
// Found, just edit
|
||||
chs[i] = *channel
|
||||
|
||||
|
@ -289,11 +294,21 @@ func (s *DefaultStore) Guilds() ([]discord.Guild, error) {
|
|||
return gs, nil
|
||||
}
|
||||
|
||||
func (s *DefaultStore) GuildSet(g *discord.Guild) error {
|
||||
func (s *DefaultStore) GuildSet(guild *discord.Guild) error {
|
||||
s.mut.Lock()
|
||||
s.guilds[g.ID] = g
|
||||
s.mut.Unlock()
|
||||
defer s.mut.Unlock()
|
||||
|
||||
if g, ok := s.guilds[guild.ID]; ok {
|
||||
// preserve state stuff
|
||||
if guild.Roles == nil {
|
||||
guild.Roles = g.Roles
|
||||
}
|
||||
if guild.Emojis == nil {
|
||||
guild.Emojis = g.Emojis
|
||||
}
|
||||
}
|
||||
|
||||
s.guilds[guild.ID] = guild
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -425,25 +440,66 @@ func (s *DefaultStore) Messages(
|
|||
return ms, nil
|
||||
}
|
||||
|
||||
func (s *DefaultStore) MaxMessages() int {
|
||||
return int(s.DefaultStoreOptions.MaxMessages)
|
||||
}
|
||||
|
||||
func (s *DefaultStore) MessageSet(message *discord.Message) error {
|
||||
s.mut.Lock()
|
||||
defer s.mut.Unlock()
|
||||
|
||||
ms, ok := s.messages[message.ChannelID]
|
||||
if !ok {
|
||||
ms = make([]discord.Message, 0, int(s.MaxMessages)+1)
|
||||
ms = make([]discord.Message, 0, s.MaxMessages()+1)
|
||||
}
|
||||
|
||||
// Append
|
||||
ms = append(ms, *message)
|
||||
// Check if we already have the message.
|
||||
for i, m := range ms {
|
||||
if m.ID == message.ID {
|
||||
// Thanks, Discord.
|
||||
if message.Content != "" {
|
||||
m.Content = message.Content
|
||||
}
|
||||
if message.EditedTimestamp != nil {
|
||||
m.EditedTimestamp = message.EditedTimestamp
|
||||
}
|
||||
if message.Mentions != nil {
|
||||
m.Mentions = message.Mentions
|
||||
}
|
||||
if message.Embeds != nil {
|
||||
m.Embeds = message.Embeds
|
||||
}
|
||||
if message.Attachments != nil {
|
||||
m.Attachments = message.Attachments
|
||||
}
|
||||
if message.Timestamp.Valid() {
|
||||
m.Timestamp = message.Timestamp
|
||||
}
|
||||
if message.Author.ID.Valid() {
|
||||
m.Author = message.Author
|
||||
}
|
||||
|
||||
// Sort (should be fast since it's presorted)
|
||||
sort.Slice(ms, func(i, j int) bool {
|
||||
return ms[i].ID > ms[j].ID
|
||||
})
|
||||
ms[i] = m
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if len(ms) > int(s.MaxMessages) {
|
||||
ms = ms[len(ms)-int(s.MaxMessages):]
|
||||
// Prepend the latest message at the end
|
||||
|
||||
if len(ms) > 0 {
|
||||
var end = s.MaxMessages()
|
||||
if len(ms) < end {
|
||||
end = len(ms)
|
||||
}
|
||||
|
||||
// Copy hack to prepend. This copies the 0th-(end-1)th entries to
|
||||
// 1st-endth.
|
||||
copy(ms[1:end], ms[0:end-1])
|
||||
// Then, set the 0th entry.
|
||||
ms[0] = *message
|
||||
|
||||
} else {
|
||||
ms = append(ms, *message)
|
||||
}
|
||||
|
||||
s.messages[message.ChannelID] = ms
|
||||
|
|
Loading…
Reference in a new issue