mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-04-15 09:44:28 +00:00
Handler: added more tests
This commit is contained in:
parent
1e95e108c8
commit
6da0f0e464
|
@ -40,7 +40,7 @@ type Handler struct {
|
|||
handlers map[uint64]handler
|
||||
horders []uint64
|
||||
hserial uint64
|
||||
hmutex sync.Mutex
|
||||
hmutex sync.RWMutex
|
||||
}
|
||||
|
||||
func New() *Handler {
|
||||
|
@ -53,8 +53,8 @@ func (h *Handler) Call(ev interface{}) {
|
|||
var evV = reflect.ValueOf(ev)
|
||||
var evT = evV.Type()
|
||||
|
||||
h.hmutex.Lock()
|
||||
defer h.hmutex.Unlock()
|
||||
h.hmutex.RLock()
|
||||
defer h.hmutex.RUnlock()
|
||||
|
||||
for _, order := range h.horders {
|
||||
handler, ok := h.handlers[order]
|
||||
|
@ -75,7 +75,9 @@ func (h *Handler) Call(ev interface{}) {
|
|||
}
|
||||
}
|
||||
|
||||
func (h *Handler) WaitFor(ctx context.Context, fn func(interface{}) bool) interface{} {
|
||||
func (h *Handler) WaitFor(
|
||||
ctx context.Context, fn func(interface{}) bool) interface{} {
|
||||
|
||||
var result = make(chan interface{})
|
||||
|
||||
cancel := h.AddHandler(func(v interface{}) {
|
||||
|
@ -83,7 +85,6 @@ func (h *Handler) WaitFor(ctx context.Context, fn func(interface{}) bool) interf
|
|||
result <- v
|
||||
}
|
||||
})
|
||||
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
|
@ -94,6 +95,24 @@ func (h *Handler) WaitFor(ctx context.Context, fn func(interface{}) bool) interf
|
|||
}
|
||||
}
|
||||
|
||||
func (h *Handler) ChanFor(fn func(interface{}) bool) <-chan interface{} {
|
||||
var result = make(chan interface{})
|
||||
|
||||
cancel := h.AddHandler(func(v interface{}) {
|
||||
if fn(v) {
|
||||
result <- v
|
||||
}
|
||||
})
|
||||
|
||||
var recv = make(chan interface{})
|
||||
go func() {
|
||||
recv <- <-result
|
||||
cancel()
|
||||
}()
|
||||
|
||||
return recv
|
||||
}
|
||||
|
||||
func (h *Handler) AddHandler(handler interface{}) (rm func()) {
|
||||
rm, err := h.addHandler(handler)
|
||||
if err != nil {
|
||||
|
@ -120,9 +139,9 @@ func (h *Handler) AddHandlerCheck(handler interface{}) (rm func(), err error) {
|
|||
return h.addHandler(handler)
|
||||
}
|
||||
|
||||
func (h *Handler) addHandler(handler interface{}) (rm func(), err error) {
|
||||
func (h *Handler) addHandler(fn interface{}) (rm func(), err error) {
|
||||
// Reflect the handler
|
||||
r, err := reflectFn(handler)
|
||||
r, err := reflectFn(fn)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Handler reflect failed")
|
||||
}
|
||||
|
@ -134,6 +153,11 @@ func (h *Handler) addHandler(handler interface{}) (rm func(), err error) {
|
|||
serial := h.hserial
|
||||
h.hserial++
|
||||
|
||||
// Create a map if there's none:
|
||||
if h.handlers == nil {
|
||||
h.handlers = map[uint64]handler{}
|
||||
}
|
||||
|
||||
// Use the serial for the map:
|
||||
h.handlers[serial] = *r
|
||||
|
||||
|
@ -175,6 +199,10 @@ func reflectFn(function interface{}) (*handler, error) {
|
|||
return nil, errors.New("function can only accept 1 event as argument")
|
||||
}
|
||||
|
||||
if fnT.NumOut() > 0 {
|
||||
return nil, errors.New("function can't accept returns")
|
||||
}
|
||||
|
||||
argT := fnT.In(0)
|
||||
kind := argT.Kind()
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
@ -123,6 +124,96 @@ func TestHandlerInterface(t *testing.T) {
|
|||
t.Fatal("Assertion failed:", recv)
|
||||
}
|
||||
|
||||
func TestHandlerWait(t *testing.T) {
|
||||
inc := make(chan interface{})
|
||||
|
||||
h := New()
|
||||
|
||||
wanted := &gateway.TypingStartEvent{
|
||||
ChannelID: 123456,
|
||||
}
|
||||
|
||||
evs := []interface{}{
|
||||
&gateway.TypingStartEvent{},
|
||||
&gateway.MessageCreateEvent{},
|
||||
&gateway.ChannelDeleteEvent{},
|
||||
wanted,
|
||||
}
|
||||
|
||||
go func() {
|
||||
inc <- h.WaitFor(context.Background(), func(v interface{}) bool {
|
||||
tp, ok := v.(*gateway.TypingStartEvent)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return tp.ChannelID == wanted.ChannelID
|
||||
})
|
||||
}()
|
||||
|
||||
var recv interface{}
|
||||
var done = make(chan struct{})
|
||||
go func() {
|
||||
recv = <-inc
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
for _, ev := range evs {
|
||||
time.Sleep(1)
|
||||
h.Call(ev)
|
||||
}
|
||||
|
||||
<-done
|
||||
if recv != wanted {
|
||||
t.Fatal("Unexpected receive:", recv)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
// Test timeout
|
||||
v := h.WaitFor(ctx, func(v interface{}) bool {
|
||||
return false
|
||||
})
|
||||
|
||||
if v != nil {
|
||||
t.Fatal("Unexpected value:", v)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlerChan(t *testing.T) {
|
||||
h := New()
|
||||
|
||||
wanted := &gateway.TypingStartEvent{
|
||||
ChannelID: 123456,
|
||||
}
|
||||
|
||||
evs := []interface{}{
|
||||
&gateway.TypingStartEvent{},
|
||||
&gateway.MessageCreateEvent{},
|
||||
&gateway.ChannelDeleteEvent{},
|
||||
wanted,
|
||||
}
|
||||
|
||||
inc := h.ChanFor(func(v interface{}) bool {
|
||||
tp, ok := v.(*gateway.TypingStartEvent)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return tp.ChannelID == wanted.ChannelID
|
||||
})
|
||||
|
||||
for _, ev := range evs {
|
||||
h.Call(ev)
|
||||
}
|
||||
|
||||
recv := <-inc
|
||||
if recv != wanted {
|
||||
t.Fatal("Unexpected receive:", recv)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkReflect(b *testing.B) {
|
||||
h, err := reflectFn(func(m *gateway.MessageCreateEvent) {})
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in a new issue