diff --git a/handler/handler.go b/handler/handler.go index 068d1fc..5c7e56b 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -173,6 +173,9 @@ func (h *Handler) ChanFor(fn func(interface{}) bool) (out <-chan interface{}, ca // should not be closed at all. The caller function WILL PANIC if the channel is // closed! // +// When the rm callback that is returned is called, it will also guarantee that +// all blocking sends will be cancelled. This helps prevent dangling goroutines. +// // // An example of a valid channel handler. // ch := make(chan *gateway.MessageCreateEvent) // h.AddHandler(ch) @@ -232,7 +235,13 @@ func (h *Handler) addHandler(fn interface{}) (rm func(), err error) { h.hmutex.Lock() defer h.hmutex.Unlock() - // Delete the handler from the map: + // Take and delete the handler from the map, but return if we can't find + // it. + hd, ok := h.handlers[serial] + if !ok { + return + } + delete(h.handlers, serial) // Delete the key from the orders slice: @@ -242,14 +251,17 @@ func (h *Handler) addHandler(fn interface{}) (rm func(), err error) { break } } + + // Clean up the handler. + hd.cleanup() }, nil } type handler struct { - event reflect.Type // underlying type; arg0 or chan underlying type - callback reflect.Value - isChan bool - isIface bool + event reflect.Type // underlying type; arg0 or chan underlying type + callback reflect.Value + isIface bool + chanclose reflect.Value // IsValid() if chan } // newHandler reflects either a channel or a function into a handler. A function @@ -260,8 +272,9 @@ func newHandler(unknown interface{}) (*handler, error) { fnT := fnV.Type() // underlying event type - var argT reflect.Type - var isch bool + var handler = handler{ + callback: fnV, + } switch fnT.Kind() { case reflect.Func: @@ -273,29 +286,26 @@ func newHandler(unknown interface{}) (*handler, error) { return nil, errors.New("function can't accept returns") } - argT = fnT.In(0) + handler.event = fnT.In(0) case reflect.Chan: - argT = fnT.Elem() - isch = true + handler.event = fnT.Elem() + handler.chanclose = reflect.ValueOf(make(chan struct{})) default: return nil, errors.New("given interface is not a function or channel") } - var kind = argT.Kind() + var kind = handler.event.Kind() // Accept either pointer type or interface{} type if kind != reflect.Ptr && kind != reflect.Interface { return nil, errors.New("first argument is not pointer") } - return &handler{ - event: argT, - callback: fnV, - isChan: isch, - isIface: kind == reflect.Interface, - }, nil + handler.isIface = kind == reflect.Interface + + return &handler, nil } func (h handler) not(event reflect.Type) bool { @@ -306,10 +316,21 @@ func (h handler) not(event reflect.Type) bool { return h.event != event } -func (h handler) call(event reflect.Value) { - if h.isChan { - h.callback.Send(event) +func (h *handler) call(event reflect.Value) { + if h.chanclose.IsValid() { + reflect.Select([]reflect.SelectCase{ + {Dir: reflect.SelectSend, Chan: h.callback, Send: event}, + {Dir: reflect.SelectRecv, Chan: h.chanclose}, + }) } else { h.callback.Call([]reflect.Value{event}) } } + +func (h *handler) cleanup() { + if h.chanclose.IsValid() { + // Closing this channel will force all ongoing selects to return + // immediately. + h.chanclose.Close() + } +} diff --git a/handler/handler_test.go b/handler/handler_test.go index 19c93ec..d92c27f 100644 --- a/handler/handler_test.go +++ b/handler/handler_test.go @@ -112,6 +112,40 @@ func TestHandlerChan(t *testing.T) { } } +func TestHandlerChanCancel(t *testing.T) { + // Never receive from this channel. + var results = make(chan *gateway.MessageCreateEvent) + + h, err := newHandler(results) + if err != nil { + t.Fatal(err) + } + + const result = "Hime Arikawa" + var msg = newMessage(result) + + var msgV = reflect.ValueOf(msg) + var msgT = msgV.Type() + + if h.not(msgT) { + t.Fatal("Event type mismatch") + } + + // Call in a goroutine, which would trigger a close. + go h.call(msgV) + + // Call the cleanup function, which should stop the send. + h.cleanup() + + // Check if we still have things being sent. + select { + case <-results: + t.Fatal("Unexpected dangling goroutine") + case <-time.After(200 * time.Millisecond): + return + } +} + func TestHandlerInterface(t *testing.T) { var results = make(chan interface{})