mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-02-08 04:28:32 +00:00
Bot: Added tests for middlewares
This commit is contained in:
parent
964e8cdf13
commit
9e59402591
|
@ -25,6 +25,8 @@ func GuildID(event interface{}) discord.Snowflake {
|
|||
|
||||
// UserID looks for fields with name UserID, User, or in some special cases, ID.
|
||||
func UserID(event interface{}) discord.Snowflake {
|
||||
// This may have a very fatal bug of accidentally mistaking another User's
|
||||
// ID. It also probably wouldn't work with things like RecipientID.
|
||||
return reflectID(reflect.ValueOf(event), "User")
|
||||
}
|
||||
|
||||
|
@ -66,13 +68,13 @@ func reflectID(v reflect.Value, thing string) discord.Snowflake {
|
|||
return chID
|
||||
}
|
||||
case reflect.Int64:
|
||||
if field.Name == thing+"ID" {
|
||||
// grab value real quick
|
||||
return discord.Snowflake(v.Field(i).Int())
|
||||
}
|
||||
switch {
|
||||
case false,
|
||||
// Contains works with "LastMessageID" and such.
|
||||
strings.Contains(field.Name, thing+"ID"),
|
||||
// Special case where the struct name has Channel in it.
|
||||
field.Name == "ID" && strings.Contains(t.Name(), thing):
|
||||
|
||||
// Special case where the struct name has Channel in it
|
||||
if field.Name == "ID" && strings.Contains(t.Name(), thing) {
|
||||
return discord.Snowflake(v.Field(i).Int())
|
||||
}
|
||||
}
|
||||
|
@ -80,3 +82,185 @@ func reflectID(v reflect.Value, thing string) discord.Snowflake {
|
|||
|
||||
return 0
|
||||
}
|
||||
|
||||
/*
|
||||
var reflectCache sync.Map
|
||||
|
||||
type cacheKey struct {
|
||||
t reflect.Type
|
||||
f string
|
||||
}
|
||||
|
||||
func getID(v reflect.Value, thing string) discord.Snowflake {
|
||||
if !v.IsValid() {
|
||||
return 0
|
||||
}
|
||||
|
||||
t := v.Type()
|
||||
|
||||
if t.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
|
||||
// Recheck after dereferring
|
||||
if !v.IsValid() {
|
||||
return 0
|
||||
}
|
||||
|
||||
t = v.Type()
|
||||
}
|
||||
|
||||
if t.Kind() != reflect.Struct {
|
||||
return 0
|
||||
}
|
||||
|
||||
return reflectID(thing, v, t)
|
||||
}
|
||||
|
||||
type reflector struct {
|
||||
steps []step
|
||||
thing string
|
||||
thingID string
|
||||
}
|
||||
|
||||
type step struct {
|
||||
field int
|
||||
ptr bool
|
||||
rec []step
|
||||
}
|
||||
|
||||
func reflectID(thing string, v reflect.Value, t reflect.Type) discord.Snowflake {
|
||||
r := &reflector{thing: thing}
|
||||
|
||||
// copy original type
|
||||
key := r.thing + t.String()
|
||||
|
||||
// check the cache
|
||||
if instructions, ok := reflectCache.Load(key); ok {
|
||||
if instructions == nil {
|
||||
return 0
|
||||
}
|
||||
return applyInstructions(v, instructions.([]step))
|
||||
}
|
||||
|
||||
r.thingID = r.thing + "ID"
|
||||
r.steps = make([]step, 0, 1)
|
||||
id := r._id(v, t)
|
||||
|
||||
if r.steps != nil {
|
||||
reflectCache.Store(key, r.instructions())
|
||||
}
|
||||
|
||||
return id
|
||||
}
|
||||
|
||||
func applyInstructions(v reflect.Value, instructions []step) discord.Snowflake {
|
||||
// Use a type here to detect recursion:
|
||||
// var originalT = v.Type()
|
||||
var laststep reflect.Value
|
||||
|
||||
log.Println(v.Type(), instructions)
|
||||
|
||||
for i, step := range instructions {
|
||||
if !v.IsValid() {
|
||||
return 0
|
||||
}
|
||||
if i > 0 && step.ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
if !v.IsValid() {
|
||||
// is this the bottom of the instructions?
|
||||
if i == len(instructions)-1 && step.rec != nil {
|
||||
for _, ins := range step.rec {
|
||||
var value = laststep.Field(ins.field)
|
||||
if ins.ptr {
|
||||
value = value.Elem()
|
||||
}
|
||||
if id := applyInstructions(value, instructions); id.Valid() {
|
||||
return id
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
laststep = v
|
||||
v = laststep.Field(step.field)
|
||||
}
|
||||
return discord.Snowflake(v.Int())
|
||||
}
|
||||
|
||||
func (r *reflector) instructions() []step {
|
||||
if len(r.steps) == 0 {
|
||||
return nil
|
||||
}
|
||||
var instructions = make([]step, len(r.steps))
|
||||
for i := 0; i < len(instructions); i++ {
|
||||
instructions[i] = r.steps[len(r.steps)-i-1]
|
||||
}
|
||||
// instructions := r.steps
|
||||
return instructions
|
||||
}
|
||||
|
||||
func (r *reflector) step(s step) {
|
||||
r.steps = append(r.steps, s)
|
||||
}
|
||||
|
||||
func (r *reflector) _id(v reflect.Value, t reflect.Type) (chID discord.Snowflake) {
|
||||
numFields := t.NumField()
|
||||
|
||||
var ptr bool
|
||||
var ins = step{field: -1}
|
||||
|
||||
for i := 0; i < numFields; i++ {
|
||||
field := t.Field(i)
|
||||
fType := field.Type
|
||||
value := v.Field(i)
|
||||
ptr = false
|
||||
|
||||
if fType.Kind() == reflect.Ptr {
|
||||
fType = fType.Elem()
|
||||
value = value.Elem()
|
||||
ptr = true
|
||||
}
|
||||
|
||||
// does laststep have the same field type?
|
||||
if fType == t {
|
||||
ins.rec = append(ins.rec, step{field: i, ptr: ptr})
|
||||
}
|
||||
|
||||
if !value.IsValid() {
|
||||
continue
|
||||
}
|
||||
|
||||
// If we've already found the field:
|
||||
if ins.field > 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
switch fType.Kind() {
|
||||
case reflect.Struct:
|
||||
if chID = r._id(value, fType); chID.Valid() {
|
||||
ins.field = i
|
||||
ins.ptr = ptr
|
||||
}
|
||||
case reflect.Int64:
|
||||
switch {
|
||||
case false,
|
||||
// Contains works with "LastMessageID" and such.
|
||||
strings.Contains(field.Name, r.thingID),
|
||||
// Special case where the struct name has Channel in it.
|
||||
field.Name == "ID" && strings.Contains(t.Name(), r.thing):
|
||||
|
||||
ins.field = i
|
||||
ins.ptr = ptr
|
||||
|
||||
chID = discord.Snowflake(value.Int())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we've found the field:
|
||||
r.step(ins)
|
||||
|
||||
return
|
||||
}
|
||||
*/
|
||||
|
|
|
@ -51,15 +51,15 @@ func TestReflectChannelID(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
var id discord.Snowflake
|
||||
|
||||
func BenchmarkReflectChannelID_1Level(b *testing.B) {
|
||||
var s = &hasID{
|
||||
ChannelID: 69420,
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
id = ChannelID(s)
|
||||
if id := ChannelID(s); id != s.ChannelID {
|
||||
b.Fatal("Unexpected ChannelID:", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -80,6 +80,8 @@ func BenchmarkReflectChannelID_5Level(b *testing.B) {
|
|||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
id = ChannelID(s)
|
||||
if id := ChannelID(s); id != 69420 {
|
||||
b.Fatal("Unexpected ChannelID:", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ func AdminOnly(ctx *bot.Context) func(interface{}) error {
|
|||
return bot.Break
|
||||
}
|
||||
|
||||
p, err := ctx.State.Permissions(channelID, userID)
|
||||
p, err := ctx.Permissions(channelID, userID)
|
||||
if err == nil && p.Has(discord.PermissionAdministrator) {
|
||||
return nil
|
||||
}
|
||||
|
@ -39,7 +39,7 @@ func GuildOnly(ctx *bot.Context) func(interface{}) error {
|
|||
return bot.Break
|
||||
}
|
||||
|
||||
c, err := ctx.State.Channel(channelID)
|
||||
c, err := ctx.Channel(channelID)
|
||||
if err != nil || !c.GuildID.Valid() {
|
||||
return bot.Break
|
||||
}
|
||||
|
|
194
bot/extras/middlewares/middlewares_test.go
Normal file
194
bot/extras/middlewares/middlewares_test.go
Normal file
|
@ -0,0 +1,194 @@
|
|||
package middlewares
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/diamondburned/arikawa/bot"
|
||||
"github.com/diamondburned/arikawa/discord"
|
||||
"github.com/diamondburned/arikawa/gateway"
|
||||
"github.com/diamondburned/arikawa/state"
|
||||
)
|
||||
|
||||
func TestAdminOnly(t *testing.T) {
|
||||
var ctx = &bot.Context{
|
||||
State: &state.State{
|
||||
Store: &mockStore{},
|
||||
},
|
||||
}
|
||||
var middleware = AdminOnly(ctx)
|
||||
|
||||
t.Run("allow message", func(t *testing.T) {
|
||||
var msg = &gateway.MessageCreateEvent{
|
||||
Message: discord.Message{
|
||||
ID: 1,
|
||||
ChannelID: 1337,
|
||||
Author: discord.User{ID: 69420},
|
||||
},
|
||||
}
|
||||
expectNil(t, middleware(msg))
|
||||
})
|
||||
|
||||
t.Run("deny message", func(t *testing.T) {
|
||||
var msg = &gateway.MessageCreateEvent{
|
||||
Message: discord.Message{
|
||||
ID: 2,
|
||||
ChannelID: 1337,
|
||||
Author: discord.User{ID: 1337},
|
||||
},
|
||||
}
|
||||
expectBreak(t, middleware(msg))
|
||||
var pin = &gateway.ChannelPinsUpdateEvent{
|
||||
ChannelID: 120,
|
||||
}
|
||||
expectBreak(t, middleware(pin))
|
||||
var tpg = &gateway.TypingStartEvent{}
|
||||
expectBreak(t, middleware(tpg))
|
||||
})
|
||||
}
|
||||
|
||||
func TestGuildOnly(t *testing.T) {
|
||||
var ctx = &bot.Context{
|
||||
State: &state.State{
|
||||
Store: &mockStore{},
|
||||
},
|
||||
}
|
||||
var middleware = GuildOnly(ctx)
|
||||
|
||||
t.Run("allow message with GuildID", func(t *testing.T) {
|
||||
var msg = &gateway.MessageCreateEvent{
|
||||
Message: discord.Message{
|
||||
ID: 3,
|
||||
GuildID: 1337,
|
||||
},
|
||||
}
|
||||
expectNil(t, middleware(msg))
|
||||
})
|
||||
|
||||
t.Run("allow message with ChannelID", func(t *testing.T) {
|
||||
var msg = &gateway.MessageCreateEvent{
|
||||
Message: discord.Message{
|
||||
ID: 3,
|
||||
ChannelID: 69420,
|
||||
},
|
||||
}
|
||||
expectNil(t, middleware(msg))
|
||||
})
|
||||
|
||||
t.Run("deny message", func(t *testing.T) {
|
||||
var msg = &gateway.MessageCreateEvent{
|
||||
Message: discord.Message{
|
||||
ID: 1,
|
||||
ChannelID: 12,
|
||||
},
|
||||
}
|
||||
expectBreak(t, middleware(msg))
|
||||
|
||||
var msg2 = &gateway.MessageCreateEvent{}
|
||||
expectBreak(t, middleware(msg2))
|
||||
})
|
||||
}
|
||||
|
||||
func expectNil(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
if err != nil {
|
||||
t.Fatal("Unexpected error:", err)
|
||||
}
|
||||
}
|
||||
|
||||
func expectBreak(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
if errors.Is(err, bot.Break) {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal("Unexpected error:", err)
|
||||
}
|
||||
t.Fatal("Expected error, got nothing.")
|
||||
}
|
||||
|
||||
// BenchmarkGuildOnly runs a message through the GuildOnly middleware to
|
||||
// calculate the overhead of reflection.
|
||||
func BenchmarkGuildOnly(b *testing.B) {
|
||||
var ctx = &bot.Context{
|
||||
State: &state.State{
|
||||
Store: &mockStore{},
|
||||
},
|
||||
}
|
||||
var middleware = GuildOnly(ctx)
|
||||
var msg = &gateway.MessageCreateEvent{
|
||||
Message: discord.Message{
|
||||
ID: 3,
|
||||
GuildID: 1337,
|
||||
},
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := middleware(msg); err != nil {
|
||||
b.Fatal("Unexpected error:", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkAdminOnly runs a message through the GuildOnly middleware to
|
||||
// calculate the overhead of reflection.
|
||||
func BenchmarkAdminOnly(b *testing.B) {
|
||||
var ctx = &bot.Context{
|
||||
State: &state.State{
|
||||
Store: &mockStore{},
|
||||
},
|
||||
}
|
||||
var middleware = AdminOnly(ctx)
|
||||
var msg = &gateway.MessageCreateEvent{
|
||||
Message: discord.Message{
|
||||
ID: 1,
|
||||
ChannelID: 1337,
|
||||
Author: discord.User{ID: 69420},
|
||||
},
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
if err := middleware(msg); err != nil {
|
||||
b.Fatal("Unexpected error:", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type mockStore struct {
|
||||
state.NoopStore
|
||||
}
|
||||
|
||||
func (s *mockStore) Guild(id discord.Snowflake) (*discord.Guild, error) {
|
||||
return &discord.Guild{
|
||||
ID: id,
|
||||
Roles: []discord.Role{{
|
||||
ID: 69420,
|
||||
Permissions: discord.PermissionAdministrator,
|
||||
}},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *mockStore) Member(g, m discord.Snowflake) (*discord.Member, error) {
|
||||
return &discord.Member{
|
||||
User: discord.User{ID: m},
|
||||
RoleIDs: []discord.Snowflake{m},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Channel returns a channel with a guildID for #69420.
|
||||
func (s *mockStore) Channel(chID discord.Snowflake) (*discord.Channel, error) {
|
||||
if chID == 69420 {
|
||||
return &discord.Channel{
|
||||
ID: chID,
|
||||
GuildID: 1337,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &discord.Channel{
|
||||
ID: chID,
|
||||
}, nil
|
||||
}
|
|
@ -1,11 +0,0 @@
|
|||
package main
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestAdminOnly(t *testing.T) {
|
||||
t.Fatal("Do me.")
|
||||
}
|
||||
|
||||
func TestGuildOnly(t *testing.T) {
|
||||
t.Fatal("Do me.")
|
||||
}
|
Loading…
Reference in a new issue