1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2025-01-24 21:46:54 +00:00
arikawa/bot/extras/middlewares/middlewares_test.go

224 lines
4.7 KiB
Go

package middlewares
import (
"errors"
"testing"
"github.com/diamondburned/arikawa/v3/bot"
"github.com/diamondburned/arikawa/v3/discord"
"github.com/diamondburned/arikawa/v3/gateway"
"github.com/diamondburned/arikawa/v3/session"
"github.com/diamondburned/arikawa/v3/state"
"github.com/diamondburned/arikawa/v3/state/store"
)
func TestAdminOnly(t *testing.T) {
var ctx = &bot.Context{
State: &state.State{
Session: &session.Session{
Gateway: &gateway.Gateway{
Identifier: &gateway.Identifier{
IdentifyData: gateway.IdentifyData{
Intents: gateway.IntentGuilds | gateway.IntentGuildMembers,
},
},
},
},
Cabinet: mockCabinet(),
},
}
var middleware = AdminOnly(ctx)
t.Run("allow message", func(t *testing.T) {
var msg = &gateway.MessageCreateEvent{
Message: discord.Message{
ID: 1,
ChannelID: 69420,
Author: discord.User{ID: 69420},
},
}
expectNil(t, middleware(msg))
})
t.Run("deny message", func(t *testing.T) {
var msg = &gateway.MessageCreateEvent{
Message: discord.Message{
ID: 2,
ChannelID: 1337,
Author: discord.User{ID: 1337},
},
}
expectBreak(t, middleware(msg))
var pin = &gateway.ChannelPinsUpdateEvent{
ChannelID: 120,
}
expectBreak(t, middleware(pin))
var tpg = &gateway.TypingStartEvent{}
expectBreak(t, middleware(tpg))
})
}
func TestGuildOnly(t *testing.T) {
var ctx = &bot.Context{
State: &state.State{
Session: &session.Session{
Gateway: &gateway.Gateway{
Identifier: &gateway.Identifier{
IdentifyData: gateway.IdentifyData{
Intents: gateway.IntentGuilds,
},
},
},
},
Cabinet: mockCabinet(),
},
}
var middleware = GuildOnly(ctx)
t.Run("allow message with GuildID", func(t *testing.T) {
var msg = &gateway.MessageCreateEvent{
Message: discord.Message{
ID: 3,
GuildID: 1337,
},
}
expectNil(t, middleware(msg))
})
t.Run("allow message with ChannelID", func(t *testing.T) {
var msg = &gateway.MessageCreateEvent{
Message: discord.Message{
ID: 3,
ChannelID: 69420,
},
}
expectNil(t, middleware(msg))
})
t.Run("deny message", func(t *testing.T) {
var msg = &gateway.MessageCreateEvent{
Message: discord.Message{
ID: 1,
ChannelID: 12,
},
}
expectBreak(t, middleware(msg))
var msg2 = &gateway.MessageCreateEvent{}
expectBreak(t, middleware(msg2))
})
}
func expectNil(t *testing.T, err error) {
t.Helper()
if err != nil {
t.Fatal("Unexpected error:", err)
}
}
func expectBreak(t *testing.T, err error) {
t.Helper()
if errors.Is(err, bot.Break) {
return
}
if err != nil {
t.Fatal("Unexpected error:", err)
}
t.Fatal("Expected error, got nothing.")
}
// BenchmarkGuildOnly runs a message through the GuildOnly middleware to
// calculate the overhead of reflection.
func BenchmarkGuildOnly(b *testing.B) {
var ctx = &bot.Context{
State: &state.State{
Cabinet: mockCabinet(),
},
}
var middleware = GuildOnly(ctx)
var msg = &gateway.MessageCreateEvent{
Message: discord.Message{
ID: 3,
GuildID: 1337,
},
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := middleware(msg); err != nil {
b.Fatal("Unexpected error:", err)
}
}
}
// BenchmarkAdminOnly runs a message through the GuildOnly middleware to
// calculate the overhead of reflection.
func BenchmarkAdminOnly(b *testing.B) {
var ctx = &bot.Context{
State: &state.State{
Cabinet: mockCabinet(),
},
}
var middleware = AdminOnly(ctx)
var msg = &gateway.MessageCreateEvent{
Message: discord.Message{
ID: 1,
ChannelID: 1337,
Author: discord.User{ID: 69420},
},
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := middleware(msg); err != nil {
b.Fatal("Unexpected error:", err)
}
}
}
type mockStore struct {
store.NoopStore
}
func mockCabinet() *store.Cabinet {
c := *store.NoopCabinet
c.GuildStore = &mockStore{}
c.MemberStore = &mockStore{}
c.ChannelStore = &mockStore{}
return &c
}
func (s *mockStore) Guild(id discord.GuildID) (*discord.Guild, error) {
return &discord.Guild{
ID: id,
Roles: []discord.Role{{
ID: 69420,
Permissions: discord.PermissionAdministrator,
}},
}, nil
}
func (s *mockStore) Member(_ discord.GuildID, userID discord.UserID) (*discord.Member, error) {
return &discord.Member{
User: discord.User{ID: userID},
RoleIDs: []discord.RoleID{discord.RoleID(userID)},
}, nil
}
// Channel returns a channel with a guildID for #69420.
func (s *mockStore) Channel(id discord.ChannelID) (*discord.Channel, error) {
if id == 69420 {
return &discord.Channel{
ID: id,
GuildID: 1337,
}, nil
}
return &discord.Channel{
ID: id,
}, nil
}