mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-01-09 21:47:07 +00:00
224 lines
4.6 KiB
Go
224 lines
4.6 KiB
Go
package middlewares
|
|
|
|
import (
|
|
"errors"
|
|
"testing"
|
|
|
|
"github.com/diamondburned/arikawa/v2/bot"
|
|
"github.com/diamondburned/arikawa/v2/discord"
|
|
"github.com/diamondburned/arikawa/v2/gateway"
|
|
"github.com/diamondburned/arikawa/v2/session"
|
|
"github.com/diamondburned/arikawa/v2/state"
|
|
"github.com/diamondburned/arikawa/v2/state/store"
|
|
)
|
|
|
|
func TestAdminOnly(t *testing.T) {
|
|
var ctx = &bot.Context{
|
|
State: &state.State{
|
|
Session: &session.Session{
|
|
Gateway: &gateway.Gateway{
|
|
Identifier: &gateway.Identifier{
|
|
IdentifyData: gateway.IdentifyData{
|
|
Intents: gateway.IntentGuilds | gateway.IntentGuildMembers,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
Cabinet: mockCabinet(),
|
|
},
|
|
}
|
|
var middleware = AdminOnly(ctx)
|
|
|
|
t.Run("allow message", func(t *testing.T) {
|
|
var msg = &gateway.MessageCreateEvent{
|
|
Message: discord.Message{
|
|
ID: 1,
|
|
ChannelID: 69420,
|
|
Author: discord.User{ID: 69420},
|
|
},
|
|
}
|
|
expectNil(t, middleware(msg))
|
|
})
|
|
|
|
t.Run("deny message", func(t *testing.T) {
|
|
var msg = &gateway.MessageCreateEvent{
|
|
Message: discord.Message{
|
|
ID: 2,
|
|
ChannelID: 1337,
|
|
Author: discord.User{ID: 1337},
|
|
},
|
|
}
|
|
expectBreak(t, middleware(msg))
|
|
var pin = &gateway.ChannelPinsUpdateEvent{
|
|
ChannelID: 120,
|
|
}
|
|
expectBreak(t, middleware(pin))
|
|
var tpg = &gateway.TypingStartEvent{}
|
|
expectBreak(t, middleware(tpg))
|
|
})
|
|
}
|
|
|
|
func TestGuildOnly(t *testing.T) {
|
|
var ctx = &bot.Context{
|
|
State: &state.State{
|
|
Session: &session.Session{
|
|
Gateway: &gateway.Gateway{
|
|
Identifier: &gateway.Identifier{
|
|
IdentifyData: gateway.IdentifyData{
|
|
Intents: gateway.IntentGuilds,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
Cabinet: mockCabinet(),
|
|
},
|
|
}
|
|
var middleware = GuildOnly(ctx)
|
|
|
|
t.Run("allow message with GuildID", func(t *testing.T) {
|
|
var msg = &gateway.MessageCreateEvent{
|
|
Message: discord.Message{
|
|
ID: 3,
|
|
GuildID: 1337,
|
|
},
|
|
}
|
|
expectNil(t, middleware(msg))
|
|
})
|
|
|
|
t.Run("allow message with ChannelID", func(t *testing.T) {
|
|
var msg = &gateway.MessageCreateEvent{
|
|
Message: discord.Message{
|
|
ID: 3,
|
|
ChannelID: 69420,
|
|
},
|
|
}
|
|
expectNil(t, middleware(msg))
|
|
})
|
|
|
|
t.Run("deny message", func(t *testing.T) {
|
|
var msg = &gateway.MessageCreateEvent{
|
|
Message: discord.Message{
|
|
ID: 1,
|
|
ChannelID: 12,
|
|
},
|
|
}
|
|
expectBreak(t, middleware(msg))
|
|
|
|
var msg2 = &gateway.MessageCreateEvent{}
|
|
expectBreak(t, middleware(msg2))
|
|
})
|
|
}
|
|
|
|
// expectNil stops the current test with a fatal error when err is non-nil.
func expectNil(t *testing.T, err error) {
	t.Helper()

	if err == nil {
		return
	}

	t.Fatal("Unexpected error:", err)
}
|
|
|
|
func expectBreak(t *testing.T, err error) {
|
|
t.Helper()
|
|
if errors.Is(err, bot.Break) {
|
|
return
|
|
}
|
|
if err != nil {
|
|
t.Fatal("Unexpected error:", err)
|
|
}
|
|
t.Fatal("Expected error, got nothing.")
|
|
}
|
|
|
|
// BenchmarkGuildOnly runs a message through the GuildOnly middleware to
|
|
// calculate the overhead of reflection.
|
|
func BenchmarkGuildOnly(b *testing.B) {
|
|
var ctx = &bot.Context{
|
|
State: &state.State{
|
|
Cabinet: mockCabinet(),
|
|
},
|
|
}
|
|
var middleware = GuildOnly(ctx)
|
|
var msg = &gateway.MessageCreateEvent{
|
|
Message: discord.Message{
|
|
ID: 3,
|
|
GuildID: 1337,
|
|
},
|
|
}
|
|
|
|
b.ResetTimer()
|
|
|
|
for i := 0; i < b.N; i++ {
|
|
if err := middleware(msg); err != nil {
|
|
b.Fatal("Unexpected error:", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// BenchmarkAdminOnly runs a message through the GuildOnly middleware to
|
|
// calculate the overhead of reflection.
|
|
func BenchmarkAdminOnly(b *testing.B) {
|
|
var ctx = &bot.Context{
|
|
State: &state.State{
|
|
Cabinet: mockCabinet(),
|
|
},
|
|
}
|
|
var middleware = AdminOnly(ctx)
|
|
var msg = &gateway.MessageCreateEvent{
|
|
Message: discord.Message{
|
|
ID: 1,
|
|
ChannelID: 1337,
|
|
Author: discord.User{ID: 69420},
|
|
},
|
|
}
|
|
|
|
b.ResetTimer()
|
|
|
|
for i := 0; i < b.N; i++ {
|
|
if err := middleware(msg); err != nil {
|
|
b.Fatal("Unexpected error:", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
type mockStore struct {
|
|
store.NoopStore
|
|
}
|
|
|
|
func mockCabinet() store.Cabinet {
|
|
c := store.NoopCabinet
|
|
c.GuildStore = &mockStore{}
|
|
c.MemberStore = &mockStore{}
|
|
c.ChannelStore = &mockStore{}
|
|
|
|
return c
|
|
}
|
|
|
|
func (s *mockStore) Guild(id discord.GuildID) (*discord.Guild, error) {
|
|
return &discord.Guild{
|
|
ID: id,
|
|
Roles: []discord.Role{{
|
|
ID: 69420,
|
|
Permissions: discord.PermissionAdministrator,
|
|
}},
|
|
}, nil
|
|
}
|
|
|
|
func (s *mockStore) Member(_ discord.GuildID, userID discord.UserID) (*discord.Member, error) {
|
|
return &discord.Member{
|
|
User: discord.User{ID: userID},
|
|
RoleIDs: []discord.RoleID{discord.RoleID(userID)},
|
|
}, nil
|
|
}
|
|
|
|
// Channel returns a channel with a guildID for #69420.
|
|
func (s *mockStore) Channel(id discord.ChannelID) (*discord.Channel, error) {
|
|
if id == 69420 {
|
|
return &discord.Channel{
|
|
ID: id,
|
|
GuildID: 1337,
|
|
}, nil
|
|
}
|
|
|
|
return &discord.Channel{
|
|
ID: id,
|
|
}, nil
|
|
}
|