1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2025-04-15 09:44:28 +00:00

Handler: added more tests

This commit is contained in:
diamondburned (Forefront) 2020-02-04 20:30:17 -08:00
parent 1e95e108c8
commit 6da0f0e464
2 changed files with 126 additions and 7 deletions

View file

@ -40,7 +40,7 @@ type Handler struct {
handlers map[uint64]handler
horders []uint64
hserial uint64
hmutex sync.Mutex
hmutex sync.RWMutex
}
func New() *Handler {
@ -53,8 +53,8 @@ func (h *Handler) Call(ev interface{}) {
var evV = reflect.ValueOf(ev)
var evT = evV.Type()
h.hmutex.Lock()
defer h.hmutex.Unlock()
h.hmutex.RLock()
defer h.hmutex.RUnlock()
for _, order := range h.horders {
handler, ok := h.handlers[order]
@ -75,7 +75,9 @@ func (h *Handler) Call(ev interface{}) {
}
}
func (h *Handler) WaitFor(ctx context.Context, fn func(interface{}) bool) interface{} {
func (h *Handler) WaitFor(
ctx context.Context, fn func(interface{}) bool) interface{} {
var result = make(chan interface{})
cancel := h.AddHandler(func(v interface{}) {
@ -83,7 +85,6 @@ func (h *Handler) WaitFor(ctx context.Context, fn func(interface{}) bool) interf
result <- v
}
})
defer cancel()
select {
@ -94,6 +95,24 @@ func (h *Handler) WaitFor(ctx context.Context, fn func(interface{}) bool) interf
}
}
func (h *Handler) ChanFor(fn func(interface{}) bool) <-chan interface{} {
var result = make(chan interface{})
cancel := h.AddHandler(func(v interface{}) {
if fn(v) {
result <- v
}
})
var recv = make(chan interface{})
go func() {
recv <- <-result
cancel()
}()
return recv
}
func (h *Handler) AddHandler(handler interface{}) (rm func()) {
rm, err := h.addHandler(handler)
if err != nil {
@ -120,9 +139,9 @@ func (h *Handler) AddHandlerCheck(handler interface{}) (rm func(), err error) {
return h.addHandler(handler)
}
func (h *Handler) addHandler(handler interface{}) (rm func(), err error) {
func (h *Handler) addHandler(fn interface{}) (rm func(), err error) {
// Reflect the handler
r, err := reflectFn(handler)
r, err := reflectFn(fn)
if err != nil {
return nil, errors.Wrap(err, "Handler reflect failed")
}
@ -134,6 +153,11 @@ func (h *Handler) addHandler(handler interface{}) (rm func(), err error) {
serial := h.hserial
h.hserial++
// Create a map if there's none:
if h.handlers == nil {
h.handlers = map[uint64]handler{}
}
// Use the serial for the map:
h.handlers[serial] = *r
@ -175,6 +199,10 @@ func reflectFn(function interface{}) (*handler, error) {
return nil, errors.New("function can only accept 1 event as argument")
}
if fnT.NumOut() > 0 {
return nil, errors.New("function can't accept returns")
}
argT := fnT.In(0)
kind := argT.Kind()

View file

@ -3,6 +3,7 @@
package handler
import (
"context"
"reflect"
"strings"
"testing"
@ -123,6 +124,96 @@ func TestHandlerInterface(t *testing.T) {
t.Fatal("Assertion failed:", recv)
}
func TestHandlerWait(t *testing.T) {
inc := make(chan interface{})
h := New()
wanted := &gateway.TypingStartEvent{
ChannelID: 123456,
}
evs := []interface{}{
&gateway.TypingStartEvent{},
&gateway.MessageCreateEvent{},
&gateway.ChannelDeleteEvent{},
wanted,
}
go func() {
inc <- h.WaitFor(context.Background(), func(v interface{}) bool {
tp, ok := v.(*gateway.TypingStartEvent)
if !ok {
return false
}
return tp.ChannelID == wanted.ChannelID
})
}()
var recv interface{}
var done = make(chan struct{})
go func() {
recv = <-inc
done <- struct{}{}
}()
for _, ev := range evs {
time.Sleep(1)
h.Call(ev)
}
<-done
if recv != wanted {
t.Fatal("Unexpected receive:", recv)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
// Test timeout
v := h.WaitFor(ctx, func(v interface{}) bool {
return false
})
if v != nil {
t.Fatal("Unexpected value:", v)
}
}
func TestHandlerChan(t *testing.T) {
h := New()
wanted := &gateway.TypingStartEvent{
ChannelID: 123456,
}
evs := []interface{}{
&gateway.TypingStartEvent{},
&gateway.MessageCreateEvent{},
&gateway.ChannelDeleteEvent{},
wanted,
}
inc := h.ChanFor(func(v interface{}) bool {
tp, ok := v.(*gateway.TypingStartEvent)
if !ok {
return false
}
return tp.ChannelID == wanted.ChannelID
})
for _, ev := range evs {
h.Call(ev)
}
recv := <-inc
if recv != wanted {
t.Fatal("Unexpected receive:", recv)
}
}
func BenchmarkReflect(b *testing.B) {
h, err := reflectFn(func(m *gateway.MessageCreateEvent) {})
if err != nil {