diff --git a/discord/snowflake.go b/discord/snowflake.go index 6b14f67..030a828 100644 --- a/discord/snowflake.go +++ b/discord/snowflake.go @@ -6,12 +6,24 @@ import ( "time" ) -const DiscordEpoch = 1420070400000 * int64(time.Millisecond) +// DiscordEpoch is the Discord epoch constant in time.Duration (nanoseconds) +// since Unix epoch. +const DiscordEpoch = 1420070400000 * time.Millisecond + +// DurationSinceDiscordEpoch returns the duration from the Discord epoch to +// current. +func DurationSinceDiscordEpoch(t time.Time) time.Duration { + return time.Duration(t.UnixNano()) - DiscordEpoch +} type Snowflake int64 +// NullSnowflake gets encoded into a null. This is used for +// optional and nullable snowflake fields. +const NullSnowflake Snowflake = -1 + func NewSnowflake(t time.Time) Snowflake { - return Snowflake(TimeToDiscordEpoch(t) << 22) + return Snowflake((DurationSinceDiscordEpoch(t) / time.Millisecond) << 22) } func ParseSnowflake(sf string) (Snowflake, error) { @@ -38,19 +50,12 @@ func (s *Snowflake) UnmarshalJSON(v []byte) error { return nil } -func (s *Snowflake) MarshalJSON() ([]byte, error) { - var id string - - switch i := int64(*s); i { - case -1: // @me - id = "@me" - case 0: +func (s Snowflake) MarshalJSON() ([]byte, error) { + if s < 1 { return []byte("null"), nil - default: - id = strconv.FormatInt(i, 10) + } else { + return []byte(`"` + strconv.FormatInt(int64(s), 10) + `"`), nil } - - return []byte(`"` + id + `"`), nil } func (s Snowflake) String() string { @@ -62,7 +67,8 @@ func (s Snowflake) Valid() bool { } func (s Snowflake) Time() time.Time { - return time.Unix(0, int64(s)>>22*1000000+DiscordEpoch) + unixnano := ((time.Duration(s) >> 22) * time.Millisecond) + DiscordEpoch + return time.Unix(0, int64(unixnano)) } func (s Snowflake) Worker() uint8 { @@ -76,7 +82,3 @@ func (s Snowflake) PID() uint8 { func (s Snowflake) Increment() uint16 { return uint16(s & 0xFFF) } - -func TimeToDiscordEpoch(t time.Time) int64 { - return t.UnixNano()/int64(time.Millisecond) - DiscordEpoch -}