mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-03-25 03:19:20 +00:00
Bot: Added tests for middlewares
This commit is contained in:
parent
964e8cdf13
commit
9e59402591
|
@ -25,6 +25,8 @@ func GuildID(event interface{}) discord.Snowflake {
|
||||||
|
|
||||||
// UserID looks for fields with name UserID, User, or in some special cases, ID.
|
// UserID looks for fields with name UserID, User, or in some special cases, ID.
|
||||||
func UserID(event interface{}) discord.Snowflake {
|
func UserID(event interface{}) discord.Snowflake {
|
||||||
|
// This may have a very fatal bug of accidentally mistaking another User's
|
||||||
|
// ID. It also probably wouldn't work with things like RecipientID.
|
||||||
return reflectID(reflect.ValueOf(event), "User")
|
return reflectID(reflect.ValueOf(event), "User")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -66,13 +68,13 @@ func reflectID(v reflect.Value, thing string) discord.Snowflake {
|
||||||
return chID
|
return chID
|
||||||
}
|
}
|
||||||
case reflect.Int64:
|
case reflect.Int64:
|
||||||
if field.Name == thing+"ID" {
|
switch {
|
||||||
// grab value real quick
|
case false,
|
||||||
return discord.Snowflake(v.Field(i).Int())
|
// Contains works with "LastMessageID" and such.
|
||||||
}
|
strings.Contains(field.Name, thing+"ID"),
|
||||||
|
// Special case where the struct name has Channel in it.
|
||||||
|
field.Name == "ID" && strings.Contains(t.Name(), thing):
|
||||||
|
|
||||||
// Special case where the struct name has Channel in it
|
|
||||||
if field.Name == "ID" && strings.Contains(t.Name(), thing) {
|
|
||||||
return discord.Snowflake(v.Field(i).Int())
|
return discord.Snowflake(v.Field(i).Int())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -80,3 +82,185 @@ func reflectID(v reflect.Value, thing string) discord.Snowflake {
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
var reflectCache sync.Map
|
||||||
|
|
||||||
|
type cacheKey struct {
|
||||||
|
t reflect.Type
|
||||||
|
f string
|
||||||
|
}
|
||||||
|
|
||||||
|
func getID(v reflect.Value, thing string) discord.Snowflake {
|
||||||
|
if !v.IsValid() {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
t := v.Type()
|
||||||
|
|
||||||
|
if t.Kind() == reflect.Ptr {
|
||||||
|
v = v.Elem()
|
||||||
|
|
||||||
|
// Recheck after dereferring
|
||||||
|
if !v.IsValid() {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
t = v.Type()
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.Kind() != reflect.Struct {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return reflectID(thing, v, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
type reflector struct {
|
||||||
|
steps []step
|
||||||
|
thing string
|
||||||
|
thingID string
|
||||||
|
}
|
||||||
|
|
||||||
|
type step struct {
|
||||||
|
field int
|
||||||
|
ptr bool
|
||||||
|
rec []step
|
||||||
|
}
|
||||||
|
|
||||||
|
func reflectID(thing string, v reflect.Value, t reflect.Type) discord.Snowflake {
|
||||||
|
r := &reflector{thing: thing}
|
||||||
|
|
||||||
|
// copy original type
|
||||||
|
key := r.thing + t.String()
|
||||||
|
|
||||||
|
// check the cache
|
||||||
|
if instructions, ok := reflectCache.Load(key); ok {
|
||||||
|
if instructions == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return applyInstructions(v, instructions.([]step))
|
||||||
|
}
|
||||||
|
|
||||||
|
r.thingID = r.thing + "ID"
|
||||||
|
r.steps = make([]step, 0, 1)
|
||||||
|
id := r._id(v, t)
|
||||||
|
|
||||||
|
if r.steps != nil {
|
||||||
|
reflectCache.Store(key, r.instructions())
|
||||||
|
}
|
||||||
|
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyInstructions(v reflect.Value, instructions []step) discord.Snowflake {
|
||||||
|
// Use a type here to detect recursion:
|
||||||
|
// var originalT = v.Type()
|
||||||
|
var laststep reflect.Value
|
||||||
|
|
||||||
|
log.Println(v.Type(), instructions)
|
||||||
|
|
||||||
|
for i, step := range instructions {
|
||||||
|
if !v.IsValid() {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if i > 0 && step.ptr {
|
||||||
|
v = v.Elem()
|
||||||
|
}
|
||||||
|
if !v.IsValid() {
|
||||||
|
// is this the bottom of the instructions?
|
||||||
|
if i == len(instructions)-1 && step.rec != nil {
|
||||||
|
for _, ins := range step.rec {
|
||||||
|
var value = laststep.Field(ins.field)
|
||||||
|
if ins.ptr {
|
||||||
|
value = value.Elem()
|
||||||
|
}
|
||||||
|
if id := applyInstructions(value, instructions); id.Valid() {
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
laststep = v
|
||||||
|
v = laststep.Field(step.field)
|
||||||
|
}
|
||||||
|
return discord.Snowflake(v.Int())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *reflector) instructions() []step {
|
||||||
|
if len(r.steps) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var instructions = make([]step, len(r.steps))
|
||||||
|
for i := 0; i < len(instructions); i++ {
|
||||||
|
instructions[i] = r.steps[len(r.steps)-i-1]
|
||||||
|
}
|
||||||
|
// instructions := r.steps
|
||||||
|
return instructions
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *reflector) step(s step) {
|
||||||
|
r.steps = append(r.steps, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *reflector) _id(v reflect.Value, t reflect.Type) (chID discord.Snowflake) {
|
||||||
|
numFields := t.NumField()
|
||||||
|
|
||||||
|
var ptr bool
|
||||||
|
var ins = step{field: -1}
|
||||||
|
|
||||||
|
for i := 0; i < numFields; i++ {
|
||||||
|
field := t.Field(i)
|
||||||
|
fType := field.Type
|
||||||
|
value := v.Field(i)
|
||||||
|
ptr = false
|
||||||
|
|
||||||
|
if fType.Kind() == reflect.Ptr {
|
||||||
|
fType = fType.Elem()
|
||||||
|
value = value.Elem()
|
||||||
|
ptr = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// does laststep have the same field type?
|
||||||
|
if fType == t {
|
||||||
|
ins.rec = append(ins.rec, step{field: i, ptr: ptr})
|
||||||
|
}
|
||||||
|
|
||||||
|
if !value.IsValid() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we've already found the field:
|
||||||
|
if ins.field > 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch fType.Kind() {
|
||||||
|
case reflect.Struct:
|
||||||
|
if chID = r._id(value, fType); chID.Valid() {
|
||||||
|
ins.field = i
|
||||||
|
ins.ptr = ptr
|
||||||
|
}
|
||||||
|
case reflect.Int64:
|
||||||
|
switch {
|
||||||
|
case false,
|
||||||
|
// Contains works with "LastMessageID" and such.
|
||||||
|
strings.Contains(field.Name, r.thingID),
|
||||||
|
// Special case where the struct name has Channel in it.
|
||||||
|
field.Name == "ID" && strings.Contains(t.Name(), r.thing):
|
||||||
|
|
||||||
|
ins.field = i
|
||||||
|
ins.ptr = ptr
|
||||||
|
|
||||||
|
chID = discord.Snowflake(value.Int())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we've found the field:
|
||||||
|
r.step(ins)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
|
@ -51,15 +51,15 @@ func TestReflectChannelID(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
var id discord.Snowflake
|
|
||||||
|
|
||||||
func BenchmarkReflectChannelID_1Level(b *testing.B) {
|
func BenchmarkReflectChannelID_1Level(b *testing.B) {
|
||||||
var s = &hasID{
|
var s = &hasID{
|
||||||
ChannelID: 69420,
|
ChannelID: 69420,
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
id = ChannelID(s)
|
if id := ChannelID(s); id != s.ChannelID {
|
||||||
|
b.Fatal("Unexpected ChannelID:", id)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -80,6 +80,8 @@ func BenchmarkReflectChannelID_5Level(b *testing.B) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
id = ChannelID(s)
|
if id := ChannelID(s); id != 69420 {
|
||||||
|
b.Fatal("Unexpected ChannelID:", id)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,7 @@ func AdminOnly(ctx *bot.Context) func(interface{}) error {
|
||||||
return bot.Break
|
return bot.Break
|
||||||
}
|
}
|
||||||
|
|
||||||
p, err := ctx.State.Permissions(channelID, userID)
|
p, err := ctx.Permissions(channelID, userID)
|
||||||
if err == nil && p.Has(discord.PermissionAdministrator) {
|
if err == nil && p.Has(discord.PermissionAdministrator) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -39,7 +39,7 @@ func GuildOnly(ctx *bot.Context) func(interface{}) error {
|
||||||
return bot.Break
|
return bot.Break
|
||||||
}
|
}
|
||||||
|
|
||||||
c, err := ctx.State.Channel(channelID)
|
c, err := ctx.Channel(channelID)
|
||||||
if err != nil || !c.GuildID.Valid() {
|
if err != nil || !c.GuildID.Valid() {
|
||||||
return bot.Break
|
return bot.Break
|
||||||
}
|
}
|
||||||
|
|
194
bot/extras/middlewares/middlewares_test.go
Normal file
194
bot/extras/middlewares/middlewares_test.go
Normal file
|
@ -0,0 +1,194 @@
|
||||||
|
package middlewares
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/diamondburned/arikawa/bot"
|
||||||
|
"github.com/diamondburned/arikawa/discord"
|
||||||
|
"github.com/diamondburned/arikawa/gateway"
|
||||||
|
"github.com/diamondburned/arikawa/state"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAdminOnly(t *testing.T) {
|
||||||
|
var ctx = &bot.Context{
|
||||||
|
State: &state.State{
|
||||||
|
Store: &mockStore{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
var middleware = AdminOnly(ctx)
|
||||||
|
|
||||||
|
t.Run("allow message", func(t *testing.T) {
|
||||||
|
var msg = &gateway.MessageCreateEvent{
|
||||||
|
Message: discord.Message{
|
||||||
|
ID: 1,
|
||||||
|
ChannelID: 1337,
|
||||||
|
Author: discord.User{ID: 69420},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
expectNil(t, middleware(msg))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("deny message", func(t *testing.T) {
|
||||||
|
var msg = &gateway.MessageCreateEvent{
|
||||||
|
Message: discord.Message{
|
||||||
|
ID: 2,
|
||||||
|
ChannelID: 1337,
|
||||||
|
Author: discord.User{ID: 1337},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
expectBreak(t, middleware(msg))
|
||||||
|
var pin = &gateway.ChannelPinsUpdateEvent{
|
||||||
|
ChannelID: 120,
|
||||||
|
}
|
||||||
|
expectBreak(t, middleware(pin))
|
||||||
|
var tpg = &gateway.TypingStartEvent{}
|
||||||
|
expectBreak(t, middleware(tpg))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGuildOnly(t *testing.T) {
|
||||||
|
var ctx = &bot.Context{
|
||||||
|
State: &state.State{
|
||||||
|
Store: &mockStore{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
var middleware = GuildOnly(ctx)
|
||||||
|
|
||||||
|
t.Run("allow message with GuildID", func(t *testing.T) {
|
||||||
|
var msg = &gateway.MessageCreateEvent{
|
||||||
|
Message: discord.Message{
|
||||||
|
ID: 3,
|
||||||
|
GuildID: 1337,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
expectNil(t, middleware(msg))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("allow message with ChannelID", func(t *testing.T) {
|
||||||
|
var msg = &gateway.MessageCreateEvent{
|
||||||
|
Message: discord.Message{
|
||||||
|
ID: 3,
|
||||||
|
ChannelID: 69420,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
expectNil(t, middleware(msg))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("deny message", func(t *testing.T) {
|
||||||
|
var msg = &gateway.MessageCreateEvent{
|
||||||
|
Message: discord.Message{
|
||||||
|
ID: 1,
|
||||||
|
ChannelID: 12,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
expectBreak(t, middleware(msg))
|
||||||
|
|
||||||
|
var msg2 = &gateway.MessageCreateEvent{}
|
||||||
|
expectBreak(t, middleware(msg2))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func expectNil(t *testing.T, err error) {
|
||||||
|
t.Helper()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Unexpected error:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func expectBreak(t *testing.T, err error) {
|
||||||
|
t.Helper()
|
||||||
|
if errors.Is(err, bot.Break) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Unexpected error:", err)
|
||||||
|
}
|
||||||
|
t.Fatal("Expected error, got nothing.")
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkGuildOnly runs a message through the GuildOnly middleware to
|
||||||
|
// calculate the overhead of reflection.
|
||||||
|
func BenchmarkGuildOnly(b *testing.B) {
|
||||||
|
var ctx = &bot.Context{
|
||||||
|
State: &state.State{
|
||||||
|
Store: &mockStore{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
var middleware = GuildOnly(ctx)
|
||||||
|
var msg = &gateway.MessageCreateEvent{
|
||||||
|
Message: discord.Message{
|
||||||
|
ID: 3,
|
||||||
|
GuildID: 1337,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if err := middleware(msg); err != nil {
|
||||||
|
b.Fatal("Unexpected error:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkAdminOnly runs a message through the GuildOnly middleware to
|
||||||
|
// calculate the overhead of reflection.
|
||||||
|
func BenchmarkAdminOnly(b *testing.B) {
|
||||||
|
var ctx = &bot.Context{
|
||||||
|
State: &state.State{
|
||||||
|
Store: &mockStore{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
var middleware = AdminOnly(ctx)
|
||||||
|
var msg = &gateway.MessageCreateEvent{
|
||||||
|
Message: discord.Message{
|
||||||
|
ID: 1,
|
||||||
|
ChannelID: 1337,
|
||||||
|
Author: discord.User{ID: 69420},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if err := middleware(msg); err != nil {
|
||||||
|
b.Fatal("Unexpected error:", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockStore struct {
|
||||||
|
state.NoopStore
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *mockStore) Guild(id discord.Snowflake) (*discord.Guild, error) {
|
||||||
|
return &discord.Guild{
|
||||||
|
ID: id,
|
||||||
|
Roles: []discord.Role{{
|
||||||
|
ID: 69420,
|
||||||
|
Permissions: discord.PermissionAdministrator,
|
||||||
|
}},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *mockStore) Member(g, m discord.Snowflake) (*discord.Member, error) {
|
||||||
|
return &discord.Member{
|
||||||
|
User: discord.User{ID: m},
|
||||||
|
RoleIDs: []discord.Snowflake{m},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Channel returns a channel with a guildID for #69420.
|
||||||
|
func (s *mockStore) Channel(chID discord.Snowflake) (*discord.Channel, error) {
|
||||||
|
if chID == 69420 {
|
||||||
|
return &discord.Channel{
|
||||||
|
ID: chID,
|
||||||
|
GuildID: 1337,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &discord.Channel{
|
||||||
|
ID: chID,
|
||||||
|
}, nil
|
||||||
|
}
|
|
@ -1,11 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func TestAdminOnly(t *testing.T) {
|
|
||||||
t.Fatal("Do me.")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGuildOnly(t *testing.T) {
|
|
||||||
t.Fatal("Do me.")
|
|
||||||
}
|
|
Loading…
Reference in a new issue