diff --git a/bot/extras/infer/infer.go b/bot/extras/infer/infer.go index 95f674e..8facefd 100644 --- a/bot/extras/infer/infer.go +++ b/bot/extras/infer/infer.go @@ -25,6 +25,8 @@ func GuildID(event interface{}) discord.Snowflake { // UserID looks for fields with name UserID, User, or in some special cases, ID. func UserID(event interface{}) discord.Snowflake { + // This may have a very fatal bug of accidentally mistaking another User's + // ID. It also probably wouldn't work with things like RecipientID. return reflectID(reflect.ValueOf(event), "User") } @@ -66,13 +68,13 @@ func reflectID(v reflect.Value, thing string) discord.Snowflake { return chID } case reflect.Int64: - if field.Name == thing+"ID" { - // grab value real quick - return discord.Snowflake(v.Field(i).Int()) - } + switch { + case false, + // Contains works with "LastMessageID" and such. + strings.Contains(field.Name, thing+"ID"), + // Special case where the struct name has Channel in it. + field.Name == "ID" && strings.Contains(t.Name(), thing): - // Special case where the struct name has Channel in it - if field.Name == "ID" && strings.Contains(t.Name(), thing) { return discord.Snowflake(v.Field(i).Int()) } } @@ -80,3 +82,185 @@ func reflectID(v reflect.Value, thing string) discord.Snowflake { return 0 } + +/* +var reflectCache sync.Map + +type cacheKey struct { + t reflect.Type + f string +} + +func getID(v reflect.Value, thing string) discord.Snowflake { + if !v.IsValid() { + return 0 + } + + t := v.Type() + + if t.Kind() == reflect.Ptr { + v = v.Elem() + + // Recheck after dereferring + if !v.IsValid() { + return 0 + } + + t = v.Type() + } + + if t.Kind() != reflect.Struct { + return 0 + } + + return reflectID(thing, v, t) +} + +type reflector struct { + steps []step + thing string + thingID string +} + +type step struct { + field int + ptr bool + rec []step +} + +func reflectID(thing string, v reflect.Value, t reflect.Type) discord.Snowflake { + r := &reflector{thing: thing} + + // copy original type + key := r.thing + t.String() + + // check the cache + if instructions, ok := reflectCache.Load(key); ok { + if instructions == nil { + return 0 + } + return applyInstructions(v, instructions.([]step)) + } + + r.thingID = r.thing + "ID" + r.steps = make([]step, 0, 1) + id := r._id(v, t) + + if r.steps != nil { + reflectCache.Store(key, r.instructions()) + } + + return id +} + +func applyInstructions(v reflect.Value, instructions []step) discord.Snowflake { + // Use a type here to detect recursion: + // var originalT = v.Type() + var laststep reflect.Value + + log.Println(v.Type(), instructions) + + for i, step := range instructions { + if !v.IsValid() { + return 0 + } + if i > 0 && step.ptr { + v = v.Elem() + } + if !v.IsValid() { + // is this the bottom of the instructions? + if i == len(instructions)-1 && step.rec != nil { + for _, ins := range step.rec { + var value = laststep.Field(ins.field) + if ins.ptr { + value = value.Elem() + } + if id := applyInstructions(value, instructions); id.Valid() { + return id + } + } + } + return 0 + } + laststep = v + v = laststep.Field(step.field) + } + return discord.Snowflake(v.Int()) +} + +func (r *reflector) instructions() []step { + if len(r.steps) == 0 { + return nil + } + var instructions = make([]step, len(r.steps)) + for i := 0; i < len(instructions); i++ { + instructions[i] = r.steps[len(r.steps)-i-1] + } + // instructions := r.steps + return instructions +} + +func (r *reflector) step(s step) { + r.steps = append(r.steps, s) +} + +func (r *reflector) _id(v reflect.Value, t reflect.Type) (chID discord.Snowflake) { + numFields := t.NumField() + + var ptr bool + var ins = step{field: -1} + + for i := 0; i < numFields; i++ { + field := t.Field(i) + fType := field.Type + value := v.Field(i) + ptr = false + + if fType.Kind() == reflect.Ptr { + fType = fType.Elem() + value = value.Elem() + ptr = true + } + + // does laststep have the same field type? + if fType == t { + ins.rec = append(ins.rec, step{field: i, ptr: ptr}) + } + + if !value.IsValid() { + continue + } + + // If we've already found the field: + if ins.field > 0 { + continue + } + + switch fType.Kind() { + case reflect.Struct: + if chID = r._id(value, fType); chID.Valid() { + ins.field = i + ins.ptr = ptr + } + case reflect.Int64: + switch { + case false, + // Contains works with "LastMessageID" and such. + strings.Contains(field.Name, r.thingID), + // Special case where the struct name has Channel in it. + field.Name == "ID" && strings.Contains(t.Name(), r.thing): + + ins.field = i + ins.ptr = ptr + + chID = discord.Snowflake(value.Int()) + } + } + } + + // If we've found the field: + r.step(ins) + + return +} +*/ diff --git a/bot/extras/infer/infer_test.go b/bot/extras/infer/infer_test.go index 12ecd4a..04eaf9a 100644 --- a/bot/extras/infer/infer_test.go +++ b/bot/extras/infer/infer_test.go @@ -51,15 +51,15 @@ func TestReflectChannelID(t *testing.T) { }) } -var id discord.Snowflake - func BenchmarkReflectChannelID_1Level(b *testing.B) { var s = &hasID{ ChannelID: 69420, } for i := 0; i < b.N; i++ { - id = ChannelID(s) + if id := ChannelID(s); id != s.ChannelID { + b.Fatal("Unexpected ChannelID:", id) + } } } @@ -80,6 +80,8 @@ func BenchmarkReflectChannelID_5Level(b *testing.B) { } for i := 0; i < b.N; i++ { - id = ChannelID(s) + if id := ChannelID(s); id != 69420 { + b.Fatal("Unexpected ChannelID:", id) + } } } diff --git a/bot/extras/middlewares/middlewares.go b/bot/extras/middlewares/middlewares.go index 299b620..e4bd473 100644 --- a/bot/extras/middlewares/middlewares.go +++ b/bot/extras/middlewares/middlewares.go @@ -18,7 +18,7 @@ func AdminOnly(ctx *bot.Context) func(interface{}) error { return bot.Break } - p, err := ctx.State.Permissions(channelID, userID) + p, err := ctx.Permissions(channelID, userID) if err == nil && p.Has(discord.PermissionAdministrator) { return nil } @@ -39,7 +39,7 @@ func GuildOnly(ctx *bot.Context) func(interface{}) error { return bot.Break } - c, err := ctx.State.Channel(channelID) + c, err := ctx.Channel(channelID) if err != nil || !c.GuildID.Valid() { return bot.Break } diff --git a/bot/extras/middlewares/middlewares_test.go b/bot/extras/middlewares/middlewares_test.go new file mode 100644 index 0000000..fa69d08 --- /dev/null +++ b/bot/extras/middlewares/middlewares_test.go @@ -0,0 +1,194 @@ +package middlewares + +import ( + "errors" + "testing" + + "github.com/diamondburned/arikawa/bot" + "github.com/diamondburned/arikawa/discord" + "github.com/diamondburned/arikawa/gateway" + "github.com/diamondburned/arikawa/state" +) + +func TestAdminOnly(t *testing.T) { + var ctx = &bot.Context{ + State: &state.State{ + Store: &mockStore{}, + }, + } + var middleware = AdminOnly(ctx) + + t.Run("allow message", func(t *testing.T) { + var msg = &gateway.MessageCreateEvent{ + Message: discord.Message{ + ID: 1, + ChannelID: 1337, + Author: discord.User{ID: 69420}, + }, + } + expectNil(t, middleware(msg)) + }) + + t.Run("deny message", func(t *testing.T) { + var msg = &gateway.MessageCreateEvent{ + Message: discord.Message{ + ID: 2, + ChannelID: 1337, + Author: discord.User{ID: 1337}, + }, + } + expectBreak(t, middleware(msg)) + var pin = &gateway.ChannelPinsUpdateEvent{ + ChannelID: 120, + } + expectBreak(t, middleware(pin)) + var tpg = &gateway.TypingStartEvent{} + expectBreak(t, middleware(tpg)) + }) +} + +func TestGuildOnly(t *testing.T) { + var ctx = &bot.Context{ + State: &state.State{ + Store: &mockStore{}, + }, + } + var middleware = GuildOnly(ctx) + + t.Run("allow message with GuildID", func(t *testing.T) { + var msg = &gateway.MessageCreateEvent{ + Message: discord.Message{ + ID: 3, + GuildID: 1337, + }, + } + expectNil(t, middleware(msg)) + }) + + t.Run("allow message with ChannelID", func(t *testing.T) { + var msg = &gateway.MessageCreateEvent{ + Message: discord.Message{ + ID: 3, + ChannelID: 69420, + }, + } + expectNil(t, middleware(msg)) + }) + + t.Run("deny message", func(t *testing.T) { + var msg = &gateway.MessageCreateEvent{ + Message: discord.Message{ + ID: 1, + ChannelID: 12, + }, + } + expectBreak(t, middleware(msg)) + + var msg2 = &gateway.MessageCreateEvent{} + expectBreak(t, middleware(msg2)) + }) +} + +func expectNil(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatal("Unexpected error:", err) + } +} + +func expectBreak(t *testing.T, err error) { + t.Helper() + if errors.Is(err, bot.Break) { + return + } + if err != nil { + t.Fatal("Unexpected error:", err) + } + t.Fatal("Expected error, got nothing.") +} + +// BenchmarkGuildOnly runs a message through the GuildOnly middleware to +// calculate the overhead of reflection. +func BenchmarkGuildOnly(b *testing.B) { + var ctx = &bot.Context{ + State: &state.State{ + Store: &mockStore{}, + }, + } + var middleware = GuildOnly(ctx) + var msg = &gateway.MessageCreateEvent{ + Message: discord.Message{ + ID: 3, + GuildID: 1337, + }, + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + if err := middleware(msg); err != nil { + b.Fatal("Unexpected error:", err) + } + } +} + +// BenchmarkAdminOnly runs a message through the GuildOnly middleware to +// calculate the overhead of reflection. +func BenchmarkAdminOnly(b *testing.B) { + var ctx = &bot.Context{ + State: &state.State{ + Store: &mockStore{}, + }, + } + var middleware = AdminOnly(ctx) + var msg = &gateway.MessageCreateEvent{ + Message: discord.Message{ + ID: 1, + ChannelID: 1337, + Author: discord.User{ID: 69420}, + }, + } + + b.ResetTimer() + + for i := 0; i < b.N; i++ { + if err := middleware(msg); err != nil { + b.Fatal("Unexpected error:", err) + } + } +} + +type mockStore struct { + state.NoopStore +} + +func (s *mockStore) Guild(id discord.Snowflake) (*discord.Guild, error) { + return &discord.Guild{ + ID: id, + Roles: []discord.Role{{ + ID: 69420, + Permissions: discord.PermissionAdministrator, + }}, + }, nil +} + +func (s *mockStore) Member(g, m discord.Snowflake) (*discord.Member, error) { + return &discord.Member{ + User: discord.User{ID: m}, + RoleIDs: []discord.Snowflake{m}, + }, nil +} + +// Channel returns a channel with a guildID for #69420. +func (s *mockStore) Channel(chID discord.Snowflake) (*discord.Channel, error) { + if chID == 69420 { + return &discord.Channel{ + ID: chID, + GuildID: 1337, + }, nil + } + + return &discord.Channel{ + ID: chID, + }, nil +} diff --git a/bot/extras/middlewares/test.go b/bot/extras/middlewares/test.go deleted file mode 100644 index 9ff2af6..0000000 --- a/bot/extras/middlewares/test.go +++ /dev/null @@ -1,11 +0,0 @@ -package main - -import "testing" - -func TestAdminOnly(t *testing.T) { - t.Fatal("Do me.") -} - -func TestGuildOnly(t *testing.T) { - t.Fatal("Do me.") -}