Bot: Added tests for middlewares

This commit is contained in:
diamondburned (Forefront) 2020-05-10 18:45:42 -07:00
parent 964e8cdf13
commit 9e59402591
5 changed files with 392 additions and 23 deletions

View File

@ -25,6 +25,8 @@ func GuildID(event interface{}) discord.Snowflake {
// UserID looks for fields with name UserID, User, or in some special cases, ID.
func UserID(event interface{}) discord.Snowflake {
// This may have a very fatal bug of accidentally mistaking another User's
// ID. It also probably wouldn't work with things like RecipientID.
return reflectID(reflect.ValueOf(event), "User")
}
@ -66,13 +68,13 @@ func reflectID(v reflect.Value, thing string) discord.Snowflake {
return chID
}
case reflect.Int64:
if field.Name == thing+"ID" {
// grab value real quick
return discord.Snowflake(v.Field(i).Int())
}
switch {
case false,
// Contains works with "LastMessageID" and such.
strings.Contains(field.Name, thing+"ID"),
// Special case where the struct name has Channel in it.
field.Name == "ID" && strings.Contains(t.Name(), thing):
// Special case where the struct name has Channel in it
if field.Name == "ID" && strings.Contains(t.Name(), thing) {
return discord.Snowflake(v.Field(i).Int())
}
}
@ -80,3 +82,185 @@ func reflectID(v reflect.Value, thing string) discord.Snowflake {
return 0
}
/*
var reflectCache sync.Map
type cacheKey struct {
t reflect.Type
f string
}
func getID(v reflect.Value, thing string) discord.Snowflake {
if !v.IsValid() {
return 0
}
t := v.Type()
if t.Kind() == reflect.Ptr {
v = v.Elem()
// Recheck after dereferring
if !v.IsValid() {
return 0
}
t = v.Type()
}
if t.Kind() != reflect.Struct {
return 0
}
return reflectID(thing, v, t)
}
type reflector struct {
steps []step
thing string
thingID string
}
type step struct {
field int
ptr bool
rec []step
}
func reflectID(thing string, v reflect.Value, t reflect.Type) discord.Snowflake {
r := &reflector{thing: thing}
// copy original type
key := r.thing + t.String()
// check the cache
if instructions, ok := reflectCache.Load(key); ok {
if instructions == nil {
return 0
}
return applyInstructions(v, instructions.([]step))
}
r.thingID = r.thing + "ID"
r.steps = make([]step, 0, 1)
id := r._id(v, t)
if r.steps != nil {
reflectCache.Store(key, r.instructions())
}
return id
}
func applyInstructions(v reflect.Value, instructions []step) discord.Snowflake {
// Use a type here to detect recursion:
// var originalT = v.Type()
var laststep reflect.Value
log.Println(v.Type(), instructions)
for i, step := range instructions {
if !v.IsValid() {
return 0
}
if i > 0 && step.ptr {
v = v.Elem()
}
if !v.IsValid() {
// is this the bottom of the instructions?
if i == len(instructions)-1 && step.rec != nil {
for _, ins := range step.rec {
var value = laststep.Field(ins.field)
if ins.ptr {
value = value.Elem()
}
if id := applyInstructions(value, instructions); id.Valid() {
return id
}
}
}
return 0
}
laststep = v
v = laststep.Field(step.field)
}
return discord.Snowflake(v.Int())
}
func (r *reflector) instructions() []step {
if len(r.steps) == 0 {
return nil
}
var instructions = make([]step, len(r.steps))
for i := 0; i < len(instructions); i++ {
instructions[i] = r.steps[len(r.steps)-i-1]
}
// instructions := r.steps
return instructions
}
func (r *reflector) step(s step) {
r.steps = append(r.steps, s)
}
func (r *reflector) _id(v reflect.Value, t reflect.Type) (chID discord.Snowflake) {
numFields := t.NumField()
var ptr bool
var ins = step{field: -1}
for i := 0; i < numFields; i++ {
field := t.Field(i)
fType := field.Type
value := v.Field(i)
ptr = false
if fType.Kind() == reflect.Ptr {
fType = fType.Elem()
value = value.Elem()
ptr = true
}
// does laststep have the same field type?
if fType == t {
ins.rec = append(ins.rec, step{field: i, ptr: ptr})
}
if !value.IsValid() {
continue
}
// If we've already found the field:
if ins.field > 0 {
continue
}
switch fType.Kind() {
case reflect.Struct:
if chID = r._id(value, fType); chID.Valid() {
ins.field = i
ins.ptr = ptr
}
case reflect.Int64:
switch {
case false,
// Contains works with "LastMessageID" and such.
strings.Contains(field.Name, r.thingID),
// Special case where the struct name has Channel in it.
field.Name == "ID" && strings.Contains(t.Name(), r.thing):
ins.field = i
ins.ptr = ptr
chID = discord.Snowflake(value.Int())
}
}
}
// If we've found the field:
r.step(ins)
return
}
*/

View File

@ -51,15 +51,15 @@ func TestReflectChannelID(t *testing.T) {
})
}
var id discord.Snowflake
func BenchmarkReflectChannelID_1Level(b *testing.B) {
var s = &hasID{
ChannelID: 69420,
}
for i := 0; i < b.N; i++ {
id = ChannelID(s)
if id := ChannelID(s); id != s.ChannelID {
b.Fatal("Unexpected ChannelID:", id)
}
}
}
@ -80,6 +80,8 @@ func BenchmarkReflectChannelID_5Level(b *testing.B) {
}
for i := 0; i < b.N; i++ {
id = ChannelID(s)
if id := ChannelID(s); id != 69420 {
b.Fatal("Unexpected ChannelID:", id)
}
}
}

View File

@ -18,7 +18,7 @@ func AdminOnly(ctx *bot.Context) func(interface{}) error {
return bot.Break
}
p, err := ctx.State.Permissions(channelID, userID)
p, err := ctx.Permissions(channelID, userID)
if err == nil && p.Has(discord.PermissionAdministrator) {
return nil
}
@ -39,7 +39,7 @@ func GuildOnly(ctx *bot.Context) func(interface{}) error {
return bot.Break
}
c, err := ctx.State.Channel(channelID)
c, err := ctx.Channel(channelID)
if err != nil || !c.GuildID.Valid() {
return bot.Break
}

View File

@ -0,0 +1,194 @@
package middlewares
import (
"errors"
"testing"
"github.com/diamondburned/arikawa/bot"
"github.com/diamondburned/arikawa/discord"
"github.com/diamondburned/arikawa/gateway"
"github.com/diamondburned/arikawa/state"
)
func TestAdminOnly(t *testing.T) {
var ctx = &bot.Context{
State: &state.State{
Store: &mockStore{},
},
}
var middleware = AdminOnly(ctx)
t.Run("allow message", func(t *testing.T) {
var msg = &gateway.MessageCreateEvent{
Message: discord.Message{
ID: 1,
ChannelID: 1337,
Author: discord.User{ID: 69420},
},
}
expectNil(t, middleware(msg))
})
t.Run("deny message", func(t *testing.T) {
var msg = &gateway.MessageCreateEvent{
Message: discord.Message{
ID: 2,
ChannelID: 1337,
Author: discord.User{ID: 1337},
},
}
expectBreak(t, middleware(msg))
var pin = &gateway.ChannelPinsUpdateEvent{
ChannelID: 120,
}
expectBreak(t, middleware(pin))
var tpg = &gateway.TypingStartEvent{}
expectBreak(t, middleware(tpg))
})
}
func TestGuildOnly(t *testing.T) {
var ctx = &bot.Context{
State: &state.State{
Store: &mockStore{},
},
}
var middleware = GuildOnly(ctx)
t.Run("allow message with GuildID", func(t *testing.T) {
var msg = &gateway.MessageCreateEvent{
Message: discord.Message{
ID: 3,
GuildID: 1337,
},
}
expectNil(t, middleware(msg))
})
t.Run("allow message with ChannelID", func(t *testing.T) {
var msg = &gateway.MessageCreateEvent{
Message: discord.Message{
ID: 3,
ChannelID: 69420,
},
}
expectNil(t, middleware(msg))
})
t.Run("deny message", func(t *testing.T) {
var msg = &gateway.MessageCreateEvent{
Message: discord.Message{
ID: 1,
ChannelID: 12,
},
}
expectBreak(t, middleware(msg))
var msg2 = &gateway.MessageCreateEvent{}
expectBreak(t, middleware(msg2))
})
}
func expectNil(t *testing.T, err error) {
t.Helper()
if err != nil {
t.Fatal("Unexpected error:", err)
}
}
func expectBreak(t *testing.T, err error) {
t.Helper()
if errors.Is(err, bot.Break) {
return
}
if err != nil {
t.Fatal("Unexpected error:", err)
}
t.Fatal("Expected error, got nothing.")
}
// BenchmarkGuildOnly runs a message through the GuildOnly middleware to
// calculate the overhead of reflection.
func BenchmarkGuildOnly(b *testing.B) {
var ctx = &bot.Context{
State: &state.State{
Store: &mockStore{},
},
}
var middleware = GuildOnly(ctx)
var msg = &gateway.MessageCreateEvent{
Message: discord.Message{
ID: 3,
GuildID: 1337,
},
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := middleware(msg); err != nil {
b.Fatal("Unexpected error:", err)
}
}
}
// BenchmarkAdminOnly runs a message through the GuildOnly middleware to
// calculate the overhead of reflection.
func BenchmarkAdminOnly(b *testing.B) {
var ctx = &bot.Context{
State: &state.State{
Store: &mockStore{},
},
}
var middleware = AdminOnly(ctx)
var msg = &gateway.MessageCreateEvent{
Message: discord.Message{
ID: 1,
ChannelID: 1337,
Author: discord.User{ID: 69420},
},
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := middleware(msg); err != nil {
b.Fatal("Unexpected error:", err)
}
}
}
type mockStore struct {
state.NoopStore
}
func (s *mockStore) Guild(id discord.Snowflake) (*discord.Guild, error) {
return &discord.Guild{
ID: id,
Roles: []discord.Role{{
ID: 69420,
Permissions: discord.PermissionAdministrator,
}},
}, nil
}
func (s *mockStore) Member(g, m discord.Snowflake) (*discord.Member, error) {
return &discord.Member{
User: discord.User{ID: m},
RoleIDs: []discord.Snowflake{m},
}, nil
}
// Channel returns a channel with a guildID for #69420.
func (s *mockStore) Channel(chID discord.Snowflake) (*discord.Channel, error) {
if chID == 69420 {
return &discord.Channel{
ID: chID,
GuildID: 1337,
}, nil
}
return &discord.Channel{
ID: chID,
}, nil
}

View File

@ -1,11 +0,0 @@
package main
import "testing"
func TestAdminOnly(t *testing.T) {
t.Fatal("Do me.")
}
func TestGuildOnly(t *testing.T) {
t.Fatal("Do me.")
}