mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-12-01 08:37:23 +00:00
Fixed bug where message state would screw up
This commit is contained in:
parent
0978d513c4
commit
27e315ca66
|
|
@ -38,6 +38,10 @@ func (t Timestamp) MarshalJSON() ([]byte, error) {
|
||||||
return []byte(`"` + time.Time(t).Format(TimestampFormat) + `"`), nil
|
return []byte(`"` + time.Time(t).Format(TimestampFormat) + `"`), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t Timestamp) Valid() bool {
|
||||||
|
return !time.Time(t).IsZero()
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
|
|
||||||
type UnixTimestamp int64
|
type UnixTimestamp int64
|
||||||
|
|
|
||||||
|
|
@ -161,9 +161,9 @@ type (
|
||||||
// Clients may only update their game status 5 times per 20 seconds.
|
// Clients may only update their game status 5 times per 20 seconds.
|
||||||
PresenceUpdateEvent discord.Presence
|
PresenceUpdateEvent discord.Presence
|
||||||
TypingStartEvent struct {
|
TypingStartEvent struct {
|
||||||
ChannelID discord.Snowflake `json:"channel_id"`
|
ChannelID discord.Snowflake `json:"channel_id"`
|
||||||
UserID discord.Snowflake `json:"user_id"`
|
UserID discord.Snowflake `json:"user_id"`
|
||||||
Timestamp discord.Timestamp `json:"timestamp"`
|
Timestamp discord.UnixTimestamp `json:"timestamp"`
|
||||||
|
|
||||||
GuildID discord.Snowflake `json:"guild_id,omitempty"`
|
GuildID discord.Snowflake `json:"guild_id,omitempty"`
|
||||||
Member *discord.Member `json:"member,omitempty"`
|
Member *discord.Member `json:"member,omitempty"`
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var WSBuffer = 12
|
var WSBuffer = 12
|
||||||
var WSReadLimit = 4096 // 4096 bytes
|
var WSReadLimit int64 = 8192000 // 8 MiB
|
||||||
|
|
||||||
// Connection is an interface that abstracts around a generic Websocket driver.
|
// Connection is an interface that abstracts around a generic Websocket driver.
|
||||||
// This connection expects the driver to handle compression by itself.
|
// This connection expects the driver to handle compression by itself.
|
||||||
|
|
@ -64,6 +64,8 @@ func (c *Conn) Dial(ctx context.Context, addr string) error {
|
||||||
HTTPHeader: headers,
|
HTTPHeader: headers,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
c.Conn.SetReadLimit(WSReadLimit)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
c.readLoop(c.events)
|
c.readLoop(c.events)
|
||||||
}()
|
}()
|
||||||
|
|
@ -109,6 +111,7 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
|
||||||
// Probably a zlib payload
|
// Probably a zlib payload
|
||||||
z, err := zlib.NewReader(r)
|
z, err := zlib.NewReader(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
c.CloseRead(ctx)
|
||||||
return nil,
|
return nil,
|
||||||
errors.Wrap(err, "Failed to create a zlib reader")
|
errors.Wrap(err, "Failed to create a zlib reader")
|
||||||
}
|
}
|
||||||
|
|
@ -117,24 +120,18 @@ func (c *Conn) readAll(ctx context.Context) ([]byte, error) {
|
||||||
r = z
|
r = z
|
||||||
}
|
}
|
||||||
|
|
||||||
return ioutil.ReadAll(r)
|
b, err := ioutil.ReadAll(r)
|
||||||
|
if err != nil {
|
||||||
|
c.CloseRead(ctx)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Send(ctx context.Context, b []byte) error {
|
func (c *Conn) Send(ctx context.Context, b []byte) error {
|
||||||
// TODO: zlib stream
|
// TODO: zlib stream
|
||||||
|
return c.Write(ctx, websocket.MessageText, b)
|
||||||
w, err := c.Writer(ctx, websocket.MessageText)
|
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "Failed to get WS writer")
|
|
||||||
}
|
|
||||||
|
|
||||||
defer w.Close()
|
|
||||||
|
|
||||||
// Compress with zlib by default NOT.
|
|
||||||
// w = zlib.NewWriter(w)
|
|
||||||
|
|
||||||
_, err = w.Write(b)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Close(err error) error {
|
func (c *Conn) Close(err error) error {
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ package state
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"log"
|
"log"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/diamondburned/arikawa/discord"
|
"github.com/diamondburned/arikawa/discord"
|
||||||
"github.com/diamondburned/arikawa/gateway"
|
"github.com/diamondburned/arikawa/gateway"
|
||||||
|
|
@ -37,6 +38,11 @@ type State struct {
|
||||||
PreHandler *handler.Handler // default nil
|
PreHandler *handler.Handler // default nil
|
||||||
|
|
||||||
unhooker func()
|
unhooker func()
|
||||||
|
|
||||||
|
// List of channels with few messages, so it doesn't bother hitting the API
|
||||||
|
// again.
|
||||||
|
fewMessages []discord.Snowflake
|
||||||
|
fewMutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFromSession(s *session.Session, store Store) (*State, error) {
|
func NewFromSession(s *session.Session, store Store) (*State, error) {
|
||||||
|
|
@ -298,9 +304,28 @@ func (s *State) Message(
|
||||||
// Messages fetches maximum 100 messages from the API, if it has to. There is no
|
// Messages fetches maximum 100 messages from the API, if it has to. There is no
|
||||||
// limit if it's from the State storage.
|
// limit if it's from the State storage.
|
||||||
func (s *State) Messages(channelID discord.Snowflake) ([]discord.Message, error) {
|
func (s *State) Messages(channelID discord.Snowflake) ([]discord.Message, error) {
|
||||||
|
// TODO: Think of a design that doesn't rely on MaxMessages().
|
||||||
|
var maxMsgs = s.MaxMessages()
|
||||||
|
|
||||||
ms, err := s.Store.Messages(channelID)
|
ms, err := s.Store.Messages(channelID)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return ms, nil
|
// If the state already has as many messages as it can, skip the API.
|
||||||
|
if maxMsgs <= len(ms) {
|
||||||
|
return ms, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Is the channel tiny?
|
||||||
|
s.fewMutex.Lock()
|
||||||
|
for _, ch := range s.fewMessages {
|
||||||
|
if ch == channelID {
|
||||||
|
// Yes, skip the state.
|
||||||
|
s.fewMutex.Unlock()
|
||||||
|
return ms, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// No, fetch from the state.
|
||||||
|
s.fewMutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
ms, err = s.Session.Messages(channelID, 100)
|
ms, err = s.Session.Messages(channelID, 100)
|
||||||
|
|
@ -314,7 +339,18 @@ func (s *State) Messages(channelID discord.Snowflake) ([]discord.Message, error)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return ms, nil
|
if len(ms) < maxMsgs {
|
||||||
|
// Tiny channel, store this.
|
||||||
|
s.fewMutex.Lock()
|
||||||
|
s.fewMessages = append(s.fewMessages, channelID)
|
||||||
|
s.fewMutex.Unlock()
|
||||||
|
|
||||||
|
return ms, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Since the latest messages are at the end and we already know the maxMsgs,
|
||||||
|
// we could slice this right away.
|
||||||
|
return ms[:maxMsgs], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
////
|
////
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@ type StoreGetter interface {
|
||||||
|
|
||||||
Message(channelID, messageID discord.Snowflake) (*discord.Message, error)
|
Message(channelID, messageID discord.Snowflake) (*discord.Message, error)
|
||||||
Messages(channelID discord.Snowflake) ([]discord.Message, error)
|
Messages(channelID discord.Snowflake) ([]discord.Message, error)
|
||||||
|
MaxMessages() int // used to know if the state is filled or not.
|
||||||
|
|
||||||
// These don't get fetched from the API, it's Gateway only.
|
// These don't get fetched from the API, it's Gateway only.
|
||||||
Presence(guildID, userID discord.Snowflake) (*discord.Presence, error)
|
Presence(guildID, userID discord.Snowflake) (*discord.Presence, error)
|
||||||
|
|
|
||||||
|
|
@ -148,6 +148,11 @@ func (s *DefaultStore) ChannelSet(channel *discord.Channel) error {
|
||||||
|
|
||||||
for i, ch := range chs {
|
for i, ch := range chs {
|
||||||
if ch.ID == channel.ID {
|
if ch.ID == channel.ID {
|
||||||
|
// Also from discordgo.
|
||||||
|
if channel.Permissions == nil {
|
||||||
|
channel.Permissions = ch.Permissions
|
||||||
|
}
|
||||||
|
|
||||||
// Found, just edit
|
// Found, just edit
|
||||||
chs[i] = *channel
|
chs[i] = *channel
|
||||||
|
|
||||||
|
|
@ -289,11 +294,21 @@ func (s *DefaultStore) Guilds() ([]discord.Guild, error) {
|
||||||
return gs, nil
|
return gs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultStore) GuildSet(g *discord.Guild) error {
|
func (s *DefaultStore) GuildSet(guild *discord.Guild) error {
|
||||||
s.mut.Lock()
|
s.mut.Lock()
|
||||||
s.guilds[g.ID] = g
|
defer s.mut.Unlock()
|
||||||
s.mut.Unlock()
|
|
||||||
|
|
||||||
|
if g, ok := s.guilds[guild.ID]; ok {
|
||||||
|
// preserve state stuff
|
||||||
|
if guild.Roles == nil {
|
||||||
|
guild.Roles = g.Roles
|
||||||
|
}
|
||||||
|
if guild.Emojis == nil {
|
||||||
|
guild.Emojis = g.Emojis
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.guilds[guild.ID] = guild
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -425,25 +440,66 @@ func (s *DefaultStore) Messages(
|
||||||
return ms, nil
|
return ms, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DefaultStore) MaxMessages() int {
|
||||||
|
return int(s.DefaultStoreOptions.MaxMessages)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *DefaultStore) MessageSet(message *discord.Message) error {
|
func (s *DefaultStore) MessageSet(message *discord.Message) error {
|
||||||
s.mut.Lock()
|
s.mut.Lock()
|
||||||
defer s.mut.Unlock()
|
defer s.mut.Unlock()
|
||||||
|
|
||||||
ms, ok := s.messages[message.ChannelID]
|
ms, ok := s.messages[message.ChannelID]
|
||||||
if !ok {
|
if !ok {
|
||||||
ms = make([]discord.Message, 0, int(s.MaxMessages)+1)
|
ms = make([]discord.Message, 0, s.MaxMessages()+1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Append
|
// Check if we already have the message.
|
||||||
ms = append(ms, *message)
|
for i, m := range ms {
|
||||||
|
if m.ID == message.ID {
|
||||||
|
// Thanks, Discord.
|
||||||
|
if message.Content != "" {
|
||||||
|
m.Content = message.Content
|
||||||
|
}
|
||||||
|
if message.EditedTimestamp != nil {
|
||||||
|
m.EditedTimestamp = message.EditedTimestamp
|
||||||
|
}
|
||||||
|
if message.Mentions != nil {
|
||||||
|
m.Mentions = message.Mentions
|
||||||
|
}
|
||||||
|
if message.Embeds != nil {
|
||||||
|
m.Embeds = message.Embeds
|
||||||
|
}
|
||||||
|
if message.Attachments != nil {
|
||||||
|
m.Attachments = message.Attachments
|
||||||
|
}
|
||||||
|
if message.Timestamp.Valid() {
|
||||||
|
m.Timestamp = message.Timestamp
|
||||||
|
}
|
||||||
|
if message.Author.ID.Valid() {
|
||||||
|
m.Author = message.Author
|
||||||
|
}
|
||||||
|
|
||||||
// Sort (should be fast since it's presorted)
|
ms[i] = m
|
||||||
sort.Slice(ms, func(i, j int) bool {
|
return nil
|
||||||
return ms[i].ID > ms[j].ID
|
}
|
||||||
})
|
}
|
||||||
|
|
||||||
if len(ms) > int(s.MaxMessages) {
|
// Prepend the latest message at the end
|
||||||
ms = ms[len(ms)-int(s.MaxMessages):]
|
|
||||||
|
if len(ms) > 0 {
|
||||||
|
var end = s.MaxMessages()
|
||||||
|
if len(ms) < end {
|
||||||
|
end = len(ms)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy hack to prepend. This copies the 0th-(end-1)th entries to
|
||||||
|
// 1st-endth.
|
||||||
|
copy(ms[1:end], ms[0:end-1])
|
||||||
|
// Then, set the 0th entry.
|
||||||
|
ms[0] = *message
|
||||||
|
|
||||||
|
} else {
|
||||||
|
ms = append(ms, *message)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.messages[message.ChannelID] = ms
|
s.messages[message.ChannelID] = ms
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue