mirror of
https://github.com/diamondburned/arikawa.git
synced 2024-11-30 18:53:30 +00:00
Bot: Added package infer for getting IDs from unknown structs
This commit is contained in:
parent
9219d2fc40
commit
7dbdc78d67
|
@ -5,6 +5,7 @@ import (
|
|||
"strings"
|
||||
|
||||
"github.com/diamondburned/arikawa/api"
|
||||
"github.com/diamondburned/arikawa/bot/extras/infer"
|
||||
"github.com/diamondburned/arikawa/discord"
|
||||
"github.com/diamondburned/arikawa/gateway"
|
||||
"github.com/pkg/errors"
|
||||
|
@ -441,12 +442,12 @@ func (ctx *Context) eventIsAdmin(ev interface{}, is **bool) bool {
|
|||
return **is
|
||||
}
|
||||
|
||||
var channelID = reflectChannelID(ev)
|
||||
var channelID = infer.ChannelID(ev)
|
||||
if !channelID.Valid() {
|
||||
return false
|
||||
}
|
||||
|
||||
var userID = reflectUserID(ev)
|
||||
var userID = infer.UserID(ev)
|
||||
if !userID.Valid() {
|
||||
return false
|
||||
}
|
||||
|
@ -467,7 +468,7 @@ func (ctx *Context) eventIsGuild(ev interface{}, is **bool) bool {
|
|||
return **is
|
||||
}
|
||||
|
||||
var channelID = reflectChannelID(ev)
|
||||
var channelID = infer.ChannelID(ev)
|
||||
if !channelID.Valid() {
|
||||
return false
|
||||
}
|
||||
|
@ -524,68 +525,3 @@ func errorReturns(returns []reflect.Value) (interface{}, error) {
|
|||
// Treat the last return as an error.
|
||||
return nil, v.(error)
|
||||
}
|
||||
|
||||
func reflectChannelID(_struct interface{}) discord.Snowflake {
|
||||
return _reflectID(reflect.ValueOf(_struct), "Channel")
|
||||
}
|
||||
|
||||
func reflectGuildID(_struct interface{}) discord.Snowflake {
|
||||
return _reflectID(reflect.ValueOf(_struct), "Guild")
|
||||
}
|
||||
|
||||
func reflectUserID(_struct interface{}) discord.Snowflake {
|
||||
return _reflectID(reflect.ValueOf(_struct), "User")
|
||||
}
|
||||
|
||||
func _reflectID(v reflect.Value, thing string) discord.Snowflake {
|
||||
if !v.IsValid() {
|
||||
return 0
|
||||
}
|
||||
|
||||
t := v.Type()
|
||||
|
||||
if t.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
|
||||
// Recheck after dereferring
|
||||
if !v.IsValid() {
|
||||
return 0
|
||||
}
|
||||
|
||||
t = v.Type()
|
||||
}
|
||||
|
||||
if t.Kind() != reflect.Struct {
|
||||
return 0
|
||||
}
|
||||
|
||||
numFields := t.NumField()
|
||||
|
||||
for i := 0; i < numFields; i++ {
|
||||
field := t.Field(i)
|
||||
fType := field.Type
|
||||
|
||||
if fType.Kind() == reflect.Ptr {
|
||||
fType = fType.Elem()
|
||||
}
|
||||
|
||||
switch fType.Kind() {
|
||||
case reflect.Struct:
|
||||
if chID := _reflectID(v.Field(i), thing); chID.Valid() {
|
||||
return chID
|
||||
}
|
||||
case reflect.Int64:
|
||||
if field.Name == thing+"ID" {
|
||||
// grab value real quick
|
||||
return discord.Snowflake(v.Field(i).Int())
|
||||
}
|
||||
|
||||
// Special case where the struct name has Channel in it
|
||||
if field.Name == "ID" && strings.Contains(t.Name(), thing) {
|
||||
return discord.Snowflake(v.Field(i).Int())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
|
|
@ -342,79 +342,3 @@ func BenchmarkHelp(b *testing.B) {
|
|||
_ = ctx.Help()
|
||||
}
|
||||
}
|
||||
|
||||
type hasID struct {
|
||||
ChannelID discord.Snowflake
|
||||
}
|
||||
|
||||
type embedsID struct {
|
||||
*hasID
|
||||
*embedsID
|
||||
}
|
||||
|
||||
type hasChannelInName struct {
|
||||
ID discord.Snowflake
|
||||
}
|
||||
|
||||
func TestReflectChannelID(t *testing.T) {
|
||||
var s = &hasID{
|
||||
ChannelID: 69420,
|
||||
}
|
||||
|
||||
t.Run("hasID", func(t *testing.T) {
|
||||
if id := reflectChannelID(s); id != 69420 {
|
||||
t.Fatal("unexpected channelID:", id)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("embedsID", func(t *testing.T) {
|
||||
var e = &embedsID{
|
||||
hasID: s,
|
||||
}
|
||||
|
||||
if id := reflectChannelID(e); id != 69420 {
|
||||
t.Fatal("unexpected channelID:", id)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("hasChannelInName", func(t *testing.T) {
|
||||
var s = &hasChannelInName{
|
||||
ID: 69420,
|
||||
}
|
||||
|
||||
if id := reflectChannelID(s); id != 69420 {
|
||||
t.Fatal("unexpected channelID:", id)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkReflectChannelID_1Level(b *testing.B) {
|
||||
var s = &hasID{
|
||||
ChannelID: 69420,
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = reflectChannelID(s)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkReflectChannelID_5Level(b *testing.B) {
|
||||
var s = &embedsID{
|
||||
nil,
|
||||
&embedsID{
|
||||
nil,
|
||||
&embedsID{
|
||||
nil,
|
||||
&embedsID{
|
||||
hasID: &hasID{
|
||||
ChannelID: 69420,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = reflectChannelID(s)
|
||||
}
|
||||
}
|
||||
|
|
82
bot/extras/infer/infer.go
Normal file
82
bot/extras/infer/infer.go
Normal file
|
@ -0,0 +1,82 @@
|
|||
// Package infer implements reflect functions that package bot uses.
|
||||
//
|
||||
// Functions in this package may run recursively forever. This shouldn't happen
|
||||
// with Arikawa's structures, but use these functions with care.
|
||||
package infer
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/diamondburned/arikawa/discord"
|
||||
)
|
||||
|
||||
// ChannelID looks for fields with name ChannelID, Channel, or in some special
|
||||
// cases, ID.
|
||||
func ChannelID(event interface{}) discord.Snowflake {
|
||||
return reflectID(reflect.ValueOf(event), "Channel")
|
||||
}
|
||||
|
||||
// GuildID looks for fields with name GuildID, Guild, or in some special cases,
|
||||
// ID.
|
||||
func GuildID(event interface{}) discord.Snowflake {
|
||||
return reflectID(reflect.ValueOf(event), "Guild")
|
||||
}
|
||||
|
||||
// UserID looks for fields with name UserID, User, or in some special cases, ID.
|
||||
func UserID(event interface{}) discord.Snowflake {
|
||||
return reflectID(reflect.ValueOf(event), "User")
|
||||
}
|
||||
|
||||
func reflectID(v reflect.Value, thing string) discord.Snowflake {
|
||||
if !v.IsValid() {
|
||||
return 0
|
||||
}
|
||||
|
||||
t := v.Type()
|
||||
|
||||
if t.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
|
||||
// Recheck after dereferring
|
||||
if !v.IsValid() {
|
||||
return 0
|
||||
}
|
||||
|
||||
t = v.Type()
|
||||
}
|
||||
|
||||
if t.Kind() != reflect.Struct {
|
||||
return 0
|
||||
}
|
||||
|
||||
numFields := t.NumField()
|
||||
|
||||
for i := 0; i < numFields; i++ {
|
||||
field := t.Field(i)
|
||||
fType := field.Type
|
||||
|
||||
if fType.Kind() == reflect.Ptr {
|
||||
fType = fType.Elem()
|
||||
}
|
||||
|
||||
switch fType.Kind() {
|
||||
case reflect.Struct:
|
||||
if chID := reflectID(v.Field(i), thing); chID.Valid() {
|
||||
return chID
|
||||
}
|
||||
case reflect.Int64:
|
||||
if field.Name == thing+"ID" {
|
||||
// grab value real quick
|
||||
return discord.Snowflake(v.Field(i).Int())
|
||||
}
|
||||
|
||||
// Special case where the struct name has Channel in it
|
||||
if field.Name == "ID" && strings.Contains(t.Name(), thing) {
|
||||
return discord.Snowflake(v.Field(i).Int())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
85
bot/extras/infer/infer_test.go
Normal file
85
bot/extras/infer/infer_test.go
Normal file
|
@ -0,0 +1,85 @@
|
|||
package infer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/diamondburned/arikawa/discord"
|
||||
)
|
||||
|
||||
type hasID struct {
|
||||
ChannelID discord.Snowflake
|
||||
}
|
||||
|
||||
type embedsID struct {
|
||||
*hasID
|
||||
*embedsID
|
||||
}
|
||||
|
||||
type hasChannelInName struct {
|
||||
ID discord.Snowflake
|
||||
}
|
||||
|
||||
func TestReflectChannelID(t *testing.T) {
|
||||
var s = &hasID{
|
||||
ChannelID: 69420,
|
||||
}
|
||||
|
||||
t.Run("hasID", func(t *testing.T) {
|
||||
if id := ChannelID(s); id != 69420 {
|
||||
t.Fatal("unexpected channelID:", id)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("embedsID", func(t *testing.T) {
|
||||
var e = &embedsID{
|
||||
hasID: s,
|
||||
}
|
||||
|
||||
if id := ChannelID(e); id != 69420 {
|
||||
t.Fatal("unexpected channelID:", id)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("hasChannelInName", func(t *testing.T) {
|
||||
var s = &hasChannelInName{
|
||||
ID: 69420,
|
||||
}
|
||||
|
||||
if id := ChannelID(s); id != 69420 {
|
||||
t.Fatal("unexpected channelID:", id)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
var id discord.Snowflake
|
||||
|
||||
func BenchmarkReflectChannelID_1Level(b *testing.B) {
|
||||
var s = &hasID{
|
||||
ChannelID: 69420,
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
id = ChannelID(s)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkReflectChannelID_5Level(b *testing.B) {
|
||||
var s = &embedsID{
|
||||
nil,
|
||||
&embedsID{
|
||||
nil,
|
||||
&embedsID{
|
||||
nil,
|
||||
&embedsID{
|
||||
hasID: &hasID{
|
||||
ChannelID: 69420,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
id = ChannelID(s)
|
||||
}
|
||||
}
|
|
@ -13,7 +13,7 @@ const None NameFlag = 0
|
|||
// These flags are applied to all events, if possible. The defined behavior
|
||||
// is to search for "ChannelID" fields or "ID" fields in structs with
|
||||
// "Channel" in its name. It doesn't handle individual events, as such, will
|
||||
// not be able to guarantee it will always work.
|
||||
// not be able to guarantee it will always work. Refer to package infer.
|
||||
|
||||
// R - Raw, which tells the library to use the method name as-is (flags will
|
||||
// still be stripped). For example, if a method is called Reset its
|
||||
|
|
Loading…
Reference in a new issue