1
0
Fork 0
mirror of https://github.com/diamondburned/arikawa.git synced 2024-11-30 18:53:30 +00:00

Handler: Added blocking send cleanup to avoid goroutine leak

This commit is contained in:
diamondburned 2020-07-15 23:11:20 -07:00
parent 6717f8002c
commit 35e143a99f
2 changed files with 75 additions and 20 deletions

View file

@ -173,6 +173,9 @@ func (h *Handler) ChanFor(fn func(interface{}) bool) (out <-chan interface{}, ca
// should not be closed at all. The caller function WILL PANIC if the channel is
// closed!
//
// When the rm callback that is returned is called, it will also guarantee that
// all blocking sends will be cancelled. This helps prevent dangling goroutines.
//
// // An example of a valid channel handler.
// ch := make(chan *gateway.MessageCreateEvent)
// h.AddHandler(ch)
@ -232,7 +235,13 @@ func (h *Handler) addHandler(fn interface{}) (rm func(), err error) {
h.hmutex.Lock()
defer h.hmutex.Unlock()
// Delete the handler from the map:
// Take and delete the handler from the map, but return if we can't find
// it.
hd, ok := h.handlers[serial]
if !ok {
return
}
delete(h.handlers, serial)
// Delete the key from the orders slice:
@ -242,14 +251,17 @@ func (h *Handler) addHandler(fn interface{}) (rm func(), err error) {
break
}
}
// Clean up the handler.
hd.cleanup()
}, nil
}
type handler struct {
event reflect.Type // underlying type; arg0 or chan underlying type
callback reflect.Value
isChan bool
isIface bool
event reflect.Type // underlying type; arg0 or chan underlying type
callback reflect.Value
isIface bool
chanclose reflect.Value // IsValid() if chan
}
// newHandler reflects either a channel or a function into a handler. A function
@ -260,8 +272,9 @@ func newHandler(unknown interface{}) (*handler, error) {
fnT := fnV.Type()
// underlying event type
var argT reflect.Type
var isch bool
var handler = handler{
callback: fnV,
}
switch fnT.Kind() {
case reflect.Func:
@ -273,29 +286,26 @@ func newHandler(unknown interface{}) (*handler, error) {
return nil, errors.New("function can't accept returns")
}
argT = fnT.In(0)
handler.event = fnT.In(0)
case reflect.Chan:
argT = fnT.Elem()
isch = true
handler.event = fnT.Elem()
handler.chanclose = reflect.ValueOf(make(chan struct{}))
default:
return nil, errors.New("given interface is not a function or channel")
}
var kind = argT.Kind()
var kind = handler.event.Kind()
// Accept either pointer type or interface{} type
if kind != reflect.Ptr && kind != reflect.Interface {
return nil, errors.New("first argument is not pointer")
}
return &handler{
event: argT,
callback: fnV,
isChan: isch,
isIface: kind == reflect.Interface,
}, nil
handler.isIface = kind == reflect.Interface
return &handler, nil
}
func (h handler) not(event reflect.Type) bool {
@ -306,10 +316,21 @@ func (h handler) not(event reflect.Type) bool {
return h.event != event
}
func (h handler) call(event reflect.Value) {
if h.isChan {
h.callback.Send(event)
func (h *handler) call(event reflect.Value) {
if h.chanclose.IsValid() {
reflect.Select([]reflect.SelectCase{
{Dir: reflect.SelectSend, Chan: h.callback, Send: event},
{Dir: reflect.SelectRecv, Chan: h.chanclose},
})
} else {
h.callback.Call([]reflect.Value{event})
}
}
func (h *handler) cleanup() {
if h.chanclose.IsValid() {
// Closing this channel will force all ongoing selects to return
// immediately.
h.chanclose.Close()
}
}

View file

@ -112,6 +112,40 @@ func TestHandlerChan(t *testing.T) {
}
}
func TestHandlerChanCancel(t *testing.T) {
// Never receive from this channel.
var results = make(chan *gateway.MessageCreateEvent)
h, err := newHandler(results)
if err != nil {
t.Fatal(err)
}
const result = "Hime Arikawa"
var msg = newMessage(result)
var msgV = reflect.ValueOf(msg)
var msgT = msgV.Type()
if h.not(msgT) {
t.Fatal("Event type mismatch")
}
// Call in a goroutine, which would trigger a close.
go h.call(msgV)
// Call the cleanup function, which should stop the send.
h.cleanup()
// Check if we still have things being sent.
select {
case <-results:
t.Fatal("Unexpected dangling goroutine")
case <-time.After(200 * time.Millisecond):
return
}
}
func TestHandlerInterface(t *testing.T) {
var results = make(chan interface{})