diff --git a/bot/ctx.go b/bot/ctx.go index 4fc49dc..742004c 100644 --- a/bot/ctx.go +++ b/bot/ctx.go @@ -143,6 +143,17 @@ func New(s *state.State, cmd interface{}) (*Context, error) { return ctx, nil } +func (ctx *Context) MustRegisterSubcommand(cmd interface{}) *Subcommand { + s, err := ctx.RegisterSubcommand(cmd) + if err != nil { + panic(err) + } + + return s +} + +// RegisterSubcommand registers and adds cmd to the list of subcommands. It will +// also return the resulting Subcommand. func (ctx *Context) RegisterSubcommand(cmd interface{}) (*Subcommand, error) { s, err := NewSubcommand(cmd) if err != nil { diff --git a/bot/extras/arguments/mention.go b/bot/extras/arguments/mention.go index e04261b..53ffa6b 100644 --- a/bot/extras/arguments/mention.go +++ b/bot/extras/arguments/mention.go @@ -3,6 +3,8 @@ package arguments import ( "errors" "regexp" + + "github.com/diamondburned/arikawa/discord" ) var ( @@ -11,42 +13,52 @@ var ( RoleRegex = regexp.MustCompile(`<@&(\d+)>`) ) -type ChannelMention string +type ChannelMention discord.Snowflake func (m *ChannelMention) Parse(arg string) error { - return grabFirst(ChannelRegex, "channel mention", arg, (*string)(m)) + return grabFirst(ChannelRegex, "channel mention", + arg, (*discord.Snowflake)(m)) } func (m *ChannelMention) Usage() string { return "#channel" } -type UserMention string +type UserMention discord.Snowflake func (m *UserMention) Parse(arg string) error { - return grabFirst(UserRegex, "user mention", arg, (*string)(m)) + return grabFirst(UserRegex, "user mention", + arg, (*discord.Snowflake)(m)) } func (m *UserMention) Usage() string { return "@user" } -type RoleMention string +type RoleMention discord.Snowflake func (m *RoleMention) Parse(arg string) error { - return grabFirst(RoleRegex, "role mention", arg, (*string)(m)) + return grabFirst(RoleRegex, "role mention", + arg, (*discord.Snowflake)(m)) } func (m *RoleMention) Usage() string { return "@role" } -func grabFirst(reg *regexp.Regexp, item, input string, output *string) error { +func grabFirst(reg *regexp.Regexp, + item, input string, output *discord.Snowflake) error { + matches := reg.FindStringSubmatch(input) if len(matches) < 2 { return errors.New("Invalid " + item) } - *output = matches[1] + id, err := discord.ParseSnowflake(matches[1]) + if err != nil { + return errors.New("Invalid " + item) + } + + *output = id return nil } diff --git a/discord/snowflake.go b/discord/snowflake.go index 3febc47..6b14f67 100644 --- a/discord/snowflake.go +++ b/discord/snowflake.go @@ -14,6 +14,15 @@ func NewSnowflake(t time.Time) Snowflake { return Snowflake(TimeToDiscordEpoch(t) << 22) } +func ParseSnowflake(sf string) (Snowflake, error) { + i, err := strconv.ParseInt(sf, 10, 64) + if err != nil { + return 0, err + } + + return Snowflake(i), nil +} + func (s *Snowflake) UnmarshalJSON(v []byte) error { id := strings.Trim(string(v), `"`) if id == "null" {