Bot: Added package infer for getting IDs from unknown structs

This commit is contained in:
diamondburned (Forefront) 2020-05-05 23:15:25 -07:00
parent 9219d2fc40
commit 7dbdc78d67
5 changed files with 172 additions and 145 deletions

View File

@ -5,6 +5,7 @@ import (
"strings"
"github.com/diamondburned/arikawa/api"
"github.com/diamondburned/arikawa/bot/extras/infer"
"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/gateway"
"github.com/pkg/errors"
@ -441,12 +442,12 @@ func (ctx *Context) eventIsAdmin(ev interface{}, is **bool) bool {
return **is
}
var channelID = reflectChannelID(ev)
var channelID = infer.ChannelID(ev)
if !channelID.Valid() {
return false
}
var userID = reflectUserID(ev)
var userID = infer.UserID(ev)
if !userID.Valid() {
return false
}
@ -467,7 +468,7 @@ func (ctx *Context) eventIsGuild(ev interface{}, is **bool) bool {
return **is
}
var channelID = reflectChannelID(ev)
var channelID = infer.ChannelID(ev)
if !channelID.Valid() {
return false
}
@ -524,68 +525,3 @@ func errorReturns(returns []reflect.Value) (interface{}, error) {
// Treat the last return as an error.
return nil, v.(error)
}
func reflectChannelID(_struct interface{}) discord.Snowflake {
return _reflectID(reflect.ValueOf(_struct), "Channel")
}
func reflectGuildID(_struct interface{}) discord.Snowflake {
return _reflectID(reflect.ValueOf(_struct), "Guild")
}
func reflectUserID(_struct interface{}) discord.Snowflake {
return _reflectID(reflect.ValueOf(_struct), "User")
}
func _reflectID(v reflect.Value, thing string) discord.Snowflake {
if !v.IsValid() {
return 0
}
t := v.Type()
if t.Kind() == reflect.Ptr {
v = v.Elem()
// Recheck after dereferring
if !v.IsValid() {
return 0
}
t = v.Type()
}
if t.Kind() != reflect.Struct {
return 0
}
numFields := t.NumField()
for i := 0; i < numFields; i++ {
field := t.Field(i)
fType := field.Type
if fType.Kind() == reflect.Ptr {
fType = fType.Elem()
}
switch fType.Kind() {
case reflect.Struct:
if chID := _reflectID(v.Field(i), thing); chID.Valid() {
return chID
}
case reflect.Int64:
if field.Name == thing+"ID" {
// grab value real quick
return discord.Snowflake(v.Field(i).Int())
}
// Special case where the struct name has Channel in it
if field.Name == "ID" && strings.Contains(t.Name(), thing) {
return discord.Snowflake(v.Field(i).Int())
}
}
}
return 0
}

View File

@ -342,79 +342,3 @@ func BenchmarkHelp(b *testing.B) {
_ = ctx.Help()
}
}
type hasID struct {
ChannelID discord.Snowflake
}
type embedsID struct {
*hasID
*embedsID
}
type hasChannelInName struct {
ID discord.Snowflake
}
func TestReflectChannelID(t *testing.T) {
var s = &hasID{
ChannelID: 69420,
}
t.Run("hasID", func(t *testing.T) {
if id := reflectChannelID(s); id != 69420 {
t.Fatal("unexpected channelID:", id)
}
})
t.Run("embedsID", func(t *testing.T) {
var e = &embedsID{
hasID: s,
}
if id := reflectChannelID(e); id != 69420 {
t.Fatal("unexpected channelID:", id)
}
})
t.Run("hasChannelInName", func(t *testing.T) {
var s = &hasChannelInName{
ID: 69420,
}
if id := reflectChannelID(s); id != 69420 {
t.Fatal("unexpected channelID:", id)
}
})
}
func BenchmarkReflectChannelID_1Level(b *testing.B) {
var s = &hasID{
ChannelID: 69420,
}
for i := 0; i < b.N; i++ {
_ = reflectChannelID(s)
}
}
func BenchmarkReflectChannelID_5Level(b *testing.B) {
var s = &embedsID{
nil,
&embedsID{
nil,
&embedsID{
nil,
&embedsID{
hasID: &hasID{
ChannelID: 69420,
},
},
},
},
}
for i := 0; i < b.N; i++ {
_ = reflectChannelID(s)
}
}

82
bot/extras/infer/infer.go Normal file
View File

@ -0,0 +1,82 @@
// Package infer implements reflect functions that package bot uses.
//
// Functions in this package may run recursively forever. This shouldn't happen
// with Arikawa's structures, but use these functions with care.
package infer
import (
"reflect"
"strings"
"github.com/diamondburned/arikawa/discord"
)
// ChannelID looks for fields with name ChannelID, Channel, or in some special
// cases, ID.
func ChannelID(event interface{}) discord.Snowflake {
return reflectID(reflect.ValueOf(event), "Channel")
}
// GuildID looks for fields with name GuildID, Guild, or in some special cases,
// ID.
func GuildID(event interface{}) discord.Snowflake {
return reflectID(reflect.ValueOf(event), "Guild")
}
// UserID looks for fields with name UserID, User, or in some special cases, ID.
func UserID(event interface{}) discord.Snowflake {
return reflectID(reflect.ValueOf(event), "User")
}
func reflectID(v reflect.Value, thing string) discord.Snowflake {
if !v.IsValid() {
return 0
}
t := v.Type()
if t.Kind() == reflect.Ptr {
v = v.Elem()
// Recheck after dereferring
if !v.IsValid() {
return 0
}
t = v.Type()
}
if t.Kind() != reflect.Struct {
return 0
}
numFields := t.NumField()
for i := 0; i < numFields; i++ {
field := t.Field(i)
fType := field.Type
if fType.Kind() == reflect.Ptr {
fType = fType.Elem()
}
switch fType.Kind() {
case reflect.Struct:
if chID := reflectID(v.Field(i), thing); chID.Valid() {
return chID
}
case reflect.Int64:
if field.Name == thing+"ID" {
// grab value real quick
return discord.Snowflake(v.Field(i).Int())
}
// Special case where the struct name has Channel in it
if field.Name == "ID" && strings.Contains(t.Name(), thing) {
return discord.Snowflake(v.Field(i).Int())
}
}
}
return 0
}

View File

@ -0,0 +1,85 @@
package infer
import (
"testing"
"github.com/diamondburned/arikawa/discord"
)
type hasID struct {
ChannelID discord.Snowflake
}
type embedsID struct {
*hasID
*embedsID
}
type hasChannelInName struct {
ID discord.Snowflake
}
func TestReflectChannelID(t *testing.T) {
var s = &hasID{
ChannelID: 69420,
}
t.Run("hasID", func(t *testing.T) {
if id := ChannelID(s); id != 69420 {
t.Fatal("unexpected channelID:", id)
}
})
t.Run("embedsID", func(t *testing.T) {
var e = &embedsID{
hasID: s,
}
if id := ChannelID(e); id != 69420 {
t.Fatal("unexpected channelID:", id)
}
})
t.Run("hasChannelInName", func(t *testing.T) {
var s = &hasChannelInName{
ID: 69420,
}
if id := ChannelID(s); id != 69420 {
t.Fatal("unexpected channelID:", id)
}
})
}
var id discord.Snowflake
func BenchmarkReflectChannelID_1Level(b *testing.B) {
var s = &hasID{
ChannelID: 69420,
}
for i := 0; i < b.N; i++ {
id = ChannelID(s)
}
}
func BenchmarkReflectChannelID_5Level(b *testing.B) {
var s = &embedsID{
nil,
&embedsID{
nil,
&embedsID{
nil,
&embedsID{
hasID: &hasID{
ChannelID: 69420,
},
},
},
},
}
for i := 0; i < b.N; i++ {
id = ChannelID(s)
}
}

View File

@ -13,7 +13,7 @@ const None NameFlag = 0
// These flags are applied to all events, if possible. The defined behavior
// is to search for "ChannelID" fields or "ID" fields in structs with
// "Channel" in its name. It doesn't handle individual events, as such, will
// not be able to guarantee it will always work.
// not be able to guarantee it will always work. Refer to package infer.
// R - Raw, which tells the library to use the method name as-is (flags will
// still be stripped). For example, if a method is called Reset its