diff --git a/api/guild.go b/api/guild.go index f61f2eb..0f58ed2 100644 --- a/api/guild.go +++ b/api/guild.go @@ -329,7 +329,7 @@ type AuditLogData struct { // ActionType is the type of audit log event. ActionType discord.AuditLogEvent `schema:"action_type,omitempty"` // Before filters the log before a certain entry ID. - Before discord.Snowflake `schema:"before,omitempty"` + Before discord.AuditLogEntryID `schema:"before,omitempty"` // Limit limits how many entries are returned (default 50, minimum 1, // maximum 100). Limit uint `schema:"limit"` diff --git a/bot/extras/arguments/mention.go b/bot/extras/arguments/mention.go index 72c606f..bff4031 100644 --- a/bot/extras/arguments/mention.go +++ b/bot/extras/arguments/mention.go @@ -25,8 +25,8 @@ func (m *ChannelMention) Usage() string { return "#channel" } -func (m *ChannelMention) ID() discord.Snowflake { - return discord.Snowflake(*m) +func (m *ChannelMention) ID() discord.ChannelID { + return discord.ChannelID(*m) } func (m *ChannelMention) Mention() string { @@ -45,8 +45,8 @@ func (m *UserMention) Usage() string { return "@user" } -func (m *UserMention) ID() discord.Snowflake { - return discord.Snowflake(*m) +func (m *UserMention) ID() discord.UserID { + return discord.UserID(*m) } func (m *UserMention) Mention() string { @@ -65,8 +65,8 @@ func (m *RoleMention) Usage() string { return "@role" } -func (m *RoleMention) ID() discord.Snowflake { - return discord.Snowflake(*m) +func (m *RoleMention) ID() discord.RoleID { + return discord.RoleID(*m) } func (m *RoleMention) Mention() string { diff --git a/bot/extras/arguments/mention_test.go b/bot/extras/arguments/mention_test.go index bb5e2eb..59c80fb 100644 --- a/bot/extras/arguments/mention_test.go +++ b/bot/extras/arguments/mention_test.go @@ -2,44 +2,58 @@ package arguments import ( "testing" - - "github.com/diamondburned/arikawa/discord" ) -func TestMention(t *testing.T) { - var ( - c ChannelMention - u UserMention - r RoleMention - ) +func TestChannelMention(t *testing.T) { + test := new(ChannelMention) + str := "<#123123>" + id := 123123 - type mention interface { - Parse(arg string) error - ID() discord.Snowflake - Mention() string + if err := test.Parse(str); err != nil { + t.Fatal("Expected", id, "error:", err) } - var tests = []struct { - mention - str string - id discord.Snowflake - }{ - {&c, "<#123123>", 123123}, - {&r, "<@&23321>", 23321}, - {&u, "<@123123>", 123123}, + if id := test.ID(); id != id { + t.Fatal("Expected", id, "got", id) } - for _, test := range tests { - if err := test.Parse(test.str); err != nil { - t.Fatal("Expected", test.id, "error:", err) - } - - if id := test.ID(); id != test.id { - t.Fatal("Expected", test.id, "got", id) - } - - if mention := test.Mention(); mention != test.str { - t.Fatal("Expected", test.str, "got", mention) - } + if mention := test.Mention(); mention != str { + t.Fatal("Expected", str, "got", mention) + } +} + +func TestUserMention(t *testing.T) { + test := new(UserMention) + str := "<@123123>" + id := 123123 + + if err := test.Parse(str); err != nil { + t.Fatal("Expected", id, "error:", err) + } + + if id := test.ID(); id != id { + t.Fatal("Expected", id, "got", id) + } + + if mention := test.Mention(); mention != str { + t.Fatal("Expected", str, "got", mention) + } +} + +func TestRoleMention(t *testing.T) { + test := new(RoleMention) + str := "<@&123123>" + id := 123123 + + if err := test.Parse(str); err != nil { + t.Fatal("Expected", id, "error:", err) + } + + if id := test.ID(); id != id { + t.Fatal("Expected", id, "got", id) + } + + if mention := test.Mention(); mention != str { + t.Fatal("Expected", str, "got", mention) } } diff --git a/bot/extras/infer/infer.go b/bot/extras/infer/infer.go index d4e1767..f7d6a7d 100644 --- a/bot/extras/infer/infer.go +++ b/bot/extras/infer/infer.go @@ -67,7 +67,7 @@ func reflectID(v reflect.Value, thing string) discord.Snowflake { if chID := reflectID(v.Field(i), thing); chID.IsValid() { return chID } - case reflect.Int64: + case reflect.Uint64: switch { case false, // Contains works with "LastMessageID" and such. @@ -75,7 +75,7 @@ func reflectID(v reflect.Value, thing string) discord.Snowflake { // Special case where the struct name has Channel in it. field.Name == "ID" && strings.Contains(t.Name(), thing): - return discord.Snowflake(v.Field(i).Int()) + return discord.Snowflake(v.Field(i).Uint()) } } } diff --git a/discord/snowflake.go b/discord/snowflake.go index 125b595..8fef467 100644 --- a/discord/snowflake.go +++ b/discord/snowflake.go @@ -6,24 +6,23 @@ import ( "time" ) -// DiscordEpoch is the Discord epoch constant in time.Duration (nanoseconds) +// Epoch is the Discord epoch constant in time.Duration (nanoseconds) // since Unix epoch. -const DiscordEpoch = 1420070400000 * time.Millisecond +const Epoch = 1420070400000 * time.Millisecond -// DurationSinceDiscordEpoch returns the duration from the Discord epoch to -// current. -func DurationSinceDiscordEpoch(t time.Time) time.Duration { - return time.Duration(t.UnixNano()) - DiscordEpoch +// DurationSinceEpoch returns the duration from the Discord epoch to current. +func DurationSinceEpoch(t time.Time) time.Duration { + return time.Duration(t.UnixNano()) - Epoch } -type Snowflake int64 +type Snowflake uint64 // NullSnowflake gets encoded into a null. This is used for // optional and nullable snowflake fields. -const NullSnowflake Snowflake = -1 +const NullSnowflake = ^Snowflake(0) func NewSnowflake(t time.Time) Snowflake { - return Snowflake((DurationSinceDiscordEpoch(t) / time.Millisecond) << 22) + return Snowflake((DurationSinceEpoch(t) / time.Millisecond) << 22) } func ParseSnowflake(sf string) (Snowflake, error) { @@ -70,7 +69,7 @@ func (s Snowflake) String() string { // IsValid returns whether or not the snowflake is valid. func (s Snowflake) IsValid() bool { - return int64(s) > 0 + return !(int64(s) == 0 || s == NullSnowflake) } // IsNull returns whether or not the snowflake is null. @@ -79,7 +78,7 @@ func (s Snowflake) IsNull() bool { } func (s Snowflake) Time() time.Time { - unixnano := ((time.Duration(s) >> 22) * time.Millisecond) + DiscordEpoch + unixnano := time.Duration(s>>22)*time.Millisecond + Epoch return time.Unix(0, int64(unixnano)) } @@ -97,6 +96,8 @@ func (s Snowflake) Increment() uint16 { type AppID Snowflake +const NullAppID = AppID(NullSnowflake) + func (s AppID) MarshalJSON() ([]byte, error) { return Snowflake(s).MarshalJSON() } func (s *AppID) UnmarshalJSON(v []byte) error { return (*Snowflake)(s).UnmarshalJSON(v) } func (s AppID) String() string { return Snowflake(s).String() } @@ -109,6 +110,8 @@ func (s AppID) Increment() uint16 { return Snowflake(s).Increment() type AttachmentID Snowflake +const NullAttachmentID = AttachmentID(NullSnowflake) + func (s AttachmentID) MarshalJSON() ([]byte, error) { return Snowflake(s).MarshalJSON() } func (s *AttachmentID) UnmarshalJSON(v []byte) error { return (*Snowflake)(s).UnmarshalJSON(v) } func (s AttachmentID) String() string { return Snowflake(s).String() } @@ -121,6 +124,8 @@ func (s AttachmentID) Increment() uint16 { return Snowflake(s).Incre type AuditLogEntryID Snowflake +const NullAuditLogEntryID = AuditLogEntryID(NullSnowflake) + func (s AuditLogEntryID) MarshalJSON() ([]byte, error) { return Snowflake(s).MarshalJSON() } func (s *AuditLogEntryID) UnmarshalJSON(v []byte) error { return (*Snowflake)(s).UnmarshalJSON(v) } func (s AuditLogEntryID) String() string { return Snowflake(s).String() } @@ -133,6 +138,8 @@ func (s AuditLogEntryID) Increment() uint16 { return Snowflake(s).In type ChannelID Snowflake +const NullChannelID = ChannelID(NullSnowflake) + func (s ChannelID) MarshalJSON() ([]byte, error) { return Snowflake(s).MarshalJSON() } func (s *ChannelID) UnmarshalJSON(v []byte) error { return (*Snowflake)(s).UnmarshalJSON(v) } func (s ChannelID) String() string { return Snowflake(s).String() } @@ -145,6 +152,8 @@ func (s ChannelID) Increment() uint16 { return Snowflake(s).Incremen type EmojiID Snowflake +const NullEmojiID = EmojiID(NullSnowflake) + func (s EmojiID) MarshalJSON() ([]byte, error) { return Snowflake(s).MarshalJSON() } func (s *EmojiID) UnmarshalJSON(v []byte) error { return (*Snowflake)(s).UnmarshalJSON(v) } func (s EmojiID) String() string { return Snowflake(s).String() } @@ -157,6 +166,8 @@ func (s EmojiID) Increment() uint16 { return Snowflake(s).Increment( type IntegrationID Snowflake +const NullIntegrationID = IntegrationID(NullSnowflake) + func (s IntegrationID) MarshalJSON() ([]byte, error) { return Snowflake(s).MarshalJSON() } func (s *IntegrationID) UnmarshalJSON(v []byte) error { return (*Snowflake)(s).UnmarshalJSON(v) } func (s IntegrationID) String() string { return Snowflake(s).String() } @@ -169,6 +180,8 @@ func (s IntegrationID) Increment() uint16 { return Snowflake(s).Incr type GuildID Snowflake +const NullGuildID = GuildID(NullSnowflake) + func (s GuildID) MarshalJSON() ([]byte, error) { return Snowflake(s).MarshalJSON() } func (s *GuildID) UnmarshalJSON(v []byte) error { return (*Snowflake)(s).UnmarshalJSON(v) } func (s GuildID) String() string { return Snowflake(s).String() } @@ -181,6 +194,8 @@ func (s GuildID) Increment() uint16 { return Snowflake(s).Increment( type MessageID Snowflake +const NullMessageID = MessageID(NullSnowflake) + func (s MessageID) MarshalJSON() ([]byte, error) { return Snowflake(s).MarshalJSON() } func (s *MessageID) UnmarshalJSON(v []byte) error { return (*Snowflake)(s).UnmarshalJSON(v) } func (s MessageID) String() string { return Snowflake(s).String() } @@ -193,6 +208,8 @@ func (s MessageID) Increment() uint16 { return Snowflake(s).Incremen type RoleID Snowflake +const NullRoleID = RoleID(NullSnowflake) + func (s RoleID) MarshalJSON() ([]byte, error) { return Snowflake(s).MarshalJSON() } func (s *RoleID) UnmarshalJSON(v []byte) error { return (*Snowflake)(s).UnmarshalJSON(v) } func (s RoleID) String() string { return Snowflake(s).String() } @@ -205,6 +222,8 @@ func (s RoleID) Increment() uint16 { return Snowflake(s).Increment() type UserID Snowflake +const NullUserID = UserID(NullSnowflake) + func (s UserID) MarshalJSON() ([]byte, error) { return Snowflake(s).MarshalJSON() } func (s *UserID) UnmarshalJSON(v []byte) error { return (*Snowflake)(s).UnmarshalJSON(v) } func (s UserID) String() string { return Snowflake(s).String() } @@ -217,6 +236,8 @@ func (s UserID) Increment() uint16 { return Snowflake(s).Increment() type WebhookID Snowflake +const NullWebhookID = WebhookID(NullSnowflake) + func (s WebhookID) MarshalJSON() ([]byte, error) { return Snowflake(s).MarshalJSON() } func (s *WebhookID) UnmarshalJSON(v []byte) error { return (*Snowflake)(s).UnmarshalJSON(v) } func (s WebhookID) String() string { return Snowflake(s).String() } diff --git a/discord/snowflake_test.go b/discord/snowflake_test.go index 6cfc3c9..a3272e6 100644 --- a/discord/snowflake_test.go +++ b/discord/snowflake_test.go @@ -36,6 +36,28 @@ func TestSnowflake(t *testing.T) { } }) + t.Run("IsValid", func(t *testing.T) { + t.Run("0", func(t *testing.T) { + if Snowflake(0).IsValid() { + t.Fatal("0 isn't a valid Snowflake") + } + }) + + t.Run("null", func(t *testing.T) { + if NullSnowflake.IsValid() { + t.Fatal("NullSnowflake isn't a valid Snowflake") + } + }) + + t.Run("valid", func(t *testing.T) { + var testFlake Snowflake = 123 + + if !testFlake.IsValid() { + t.Fatal(testFlake, "is a valid Snowflake") + } + }) + }) + t.Run("new", func(t *testing.T) { if s := NewSnowflake(expect); !s.Time().Equal(expect) { t.Fatal("Unexpected new snowflake from expected time:", s) diff --git a/voice/session.go b/voice/session.go index 59a2b3a..291aadc 100644 --- a/voice/session.go +++ b/voice/session.go @@ -138,7 +138,7 @@ func (s *Session) JoinChannelCtx(ctx context.Context, gID discord.GuildID, cID d s.speaking = false // Ensure that if `cID` is zero that it passes null to the update event. - var channelID discord.ChannelID = -1 + channelID := discord.NullChannelID if cID.IsValid() { channelID = cID }