mirror of
https://github.com/diamondburned/arikawa.git
synced 2024-12-02 11:52:56 +00:00
Added Session + Handlers
This commit is contained in:
parent
9df384bf32
commit
4a529dd2ec
43
session/handler.go
Normal file
43
session/handler.go
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
type handler struct {
|
||||||
|
event reflect.Type
|
||||||
|
callback reflect.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
func reflectFn(function interface{}) (*handler, error) {
|
||||||
|
fnV := reflect.ValueOf(function)
|
||||||
|
fnT := fnV.Type()
|
||||||
|
|
||||||
|
if fnT.Kind() != reflect.Func {
|
||||||
|
return nil, errors.New("given interface is not a function")
|
||||||
|
}
|
||||||
|
|
||||||
|
if fnT.NumIn() != 1 {
|
||||||
|
return nil, errors.New("function can only accept 1 event as argument")
|
||||||
|
}
|
||||||
|
|
||||||
|
argT := fnT.In(0)
|
||||||
|
|
||||||
|
if argT.Kind() != reflect.Ptr {
|
||||||
|
return nil, errors.New("first argument is not pointer")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &handler{
|
||||||
|
event: argT,
|
||||||
|
callback: fnV,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h handler) not(event reflect.Type) bool {
|
||||||
|
return h.event != event
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h handler) call(event reflect.Value) {
|
||||||
|
h.callback.Call([]reflect.Value{event})
|
||||||
|
}
|
59
session/handler_test.go
Normal file
59
session/handler_test.go
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/diamondburned/arikawa/gateway"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestReflect(t *testing.T) {
|
||||||
|
var results = make(chan string)
|
||||||
|
|
||||||
|
h, err := reflectFn(func(m *gateway.MessageCreateEvent) {
|
||||||
|
results <- m.Content
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = "Hime Arikawa"
|
||||||
|
var msg = &gateway.MessageCreateEvent{
|
||||||
|
Content: result,
|
||||||
|
}
|
||||||
|
|
||||||
|
var msgV = reflect.ValueOf(msg)
|
||||||
|
var msgT = msgV.Type()
|
||||||
|
|
||||||
|
if h.not(msgT) {
|
||||||
|
t.Fatal("Event type mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
go h.call(msgV)
|
||||||
|
|
||||||
|
if results := <-results; results != result {
|
||||||
|
t.Fatal("Unexpected results:", results)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkReflect(b *testing.B) {
|
||||||
|
h, err := reflectFn(func(m *gateway.MessageCreateEvent) {})
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg = &gateway.MessageCreateEvent{}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for n := 0; n < b.N; n++ {
|
||||||
|
var msgV = reflect.ValueOf(msg)
|
||||||
|
var msgT = msgV.Type()
|
||||||
|
|
||||||
|
if h.not(msgT) {
|
||||||
|
b.Fatal("Event type mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
h.call(msgV)
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,13 +1,14 @@
|
||||||
package session
|
package session
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"reflect"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/diamondburned/arikawa/api"
|
"github.com/diamondburned/arikawa/api"
|
||||||
"github.com/diamondburned/arikawa/gateway"
|
"github.com/diamondburned/arikawa/gateway"
|
||||||
"github.com/diamondburned/arikawa/internal/json"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -28,56 +29,153 @@ import (
|
||||||
*/
|
*/
|
||||||
|
|
||||||
type Session struct {
|
type Session struct {
|
||||||
API *api.Client
|
*api.Client
|
||||||
Gateway *gateway.Conn
|
gateway *gateway.Gateway
|
||||||
gatewayOnce sync.Once
|
|
||||||
|
|
||||||
ErrorLog func(err error) // default to log.Println
|
ErrorLog func(err error) // default to log.Println
|
||||||
|
|
||||||
// Heartrate is the received duration between heartbeats.
|
// Synchronous controls whether to spawn each event handler in its own
|
||||||
Heartrate time.Duration
|
// goroutine. Default false (meaning goroutines are spawned).
|
||||||
|
Synchronous bool
|
||||||
|
|
||||||
// LastBeat logs the received heartbeats, with the newest one
|
// handlers stuff
|
||||||
// first.
|
handlers map[uint64]handler
|
||||||
LastBeat [2]time.Time
|
hserial uint64
|
||||||
|
hmutex sync.Mutex
|
||||||
// Used for Close()
|
hstop chan<- struct{}
|
||||||
stoppers []chan<- struct{}
|
|
||||||
closers []func() error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(token string) (*Session, error) {
|
func New(token string) (*Session, error) {
|
||||||
// Initialize the session and the API interface
|
// Initialize the session and the API interface
|
||||||
s := &Session{}
|
s := &Session{}
|
||||||
s.API = api.NewClient(token)
|
s.Client = api.NewClient(token)
|
||||||
|
|
||||||
// Default logger
|
// Default logger
|
||||||
s.ErrorLog = func(err error) {
|
s.ErrorLog = func(err error) {
|
||||||
log.Println("Arikawa/session error:", err)
|
log.Println("Arikawa/session error:", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect to the Gateway
|
// Open a gateway
|
||||||
c, err := gateway.NewConn(json.Default{})
|
g, err := gateway.NewGateway(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.Wrap(err, "Failed to connect to Gateway")
|
||||||
}
|
}
|
||||||
s.Gateway = c
|
s.gateway = g
|
||||||
|
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) Close() error {
|
func NewWithGateway(gw *gateway.Gateway) *Session {
|
||||||
for _, stop := range s.stoppers {
|
return &Session{
|
||||||
close(stop)
|
// Nab off gateway's token
|
||||||
|
Client: api.NewClient(gw.Identifier.Token),
|
||||||
|
ErrorLog: func(err error) {
|
||||||
|
log.Println("Arikawa/session error:", err)
|
||||||
|
},
|
||||||
|
handlers: map[uint64]handler{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) Open() error {
|
||||||
|
if err := s.gateway.Start(); err != nil {
|
||||||
|
return errors.Wrap(err, "Failed to start gateway")
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
stop := make(chan struct{})
|
||||||
|
s.hstop = stop
|
||||||
|
go s.startHandler(stop)
|
||||||
|
|
||||||
for _, closer := range s.closers {
|
return nil
|
||||||
if cerr := closer(); cerr != nil {
|
}
|
||||||
err = cerr
|
|
||||||
|
func (s *Session) AddHandler(handler interface{}) (rm func()) {
|
||||||
|
rm, err := s.addHandler(handler)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return rm
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddHandlerCheck adds the handler, but safe-guards reflect panics with a
|
||||||
|
// recoverer, returning the error.
|
||||||
|
func (s *Session) AddHandlerCheck(handler interface{}) (rm func(), err error) {
|
||||||
|
// Reflect would actually panic if anything goes wrong, so this is just in
|
||||||
|
// case.
|
||||||
|
defer func() {
|
||||||
|
if rec := recover(); rec != nil {
|
||||||
|
if recErr, ok := rec.(error); ok {
|
||||||
|
err = recErr
|
||||||
|
} else {
|
||||||
|
err = fmt.Errorf("%v", rec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return s.addHandler(handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) addHandler(handler interface{}) (rm func(), err error) {
|
||||||
|
// Reflect the handler
|
||||||
|
h, err := reflectFn(handler)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "Handler reflect failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.hmutex.Lock()
|
||||||
|
defer s.hmutex.Unlock()
|
||||||
|
|
||||||
|
// Get the current counter value and increment the counter
|
||||||
|
serial := s.hserial
|
||||||
|
s.hserial++
|
||||||
|
|
||||||
|
// Use the serial for the map
|
||||||
|
s.handlers[serial] = *h
|
||||||
|
|
||||||
|
return func() {
|
||||||
|
s.hmutex.Lock()
|
||||||
|
defer s.hmutex.Unlock()
|
||||||
|
|
||||||
|
delete(s.handlers, serial)
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) startHandler(stop <-chan struct{}) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-stop:
|
||||||
|
return
|
||||||
|
case ev := <-s.gateway.Events:
|
||||||
|
s.call(ev)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return err
|
|
||||||
|
func (s *Session) call(ev interface{}) {
|
||||||
|
var evV = reflect.ValueOf(ev)
|
||||||
|
var evT = evV.Type()
|
||||||
|
|
||||||
|
s.hmutex.Lock()
|
||||||
|
defer s.hmutex.Unlock()
|
||||||
|
|
||||||
|
for _, handler := range s.handlers {
|
||||||
|
if handler.not(evT) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.Synchronous {
|
||||||
|
handler.call(evV)
|
||||||
|
} else {
|
||||||
|
go handler.call(evV)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) Close() error {
|
||||||
|
// Stop the event handler
|
||||||
|
if s.hstop != nil {
|
||||||
|
close(s.hstop)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the websocket
|
||||||
|
return s.gateway.Close()
|
||||||
}
|
}
|
||||||
|
|
55
session/session_test.go
Normal file
55
session/session_test.go
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/diamondburned/arikawa/gateway"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSessionCall(t *testing.T) {
|
||||||
|
var results = make(chan string)
|
||||||
|
|
||||||
|
s := &Session{
|
||||||
|
handlers: map[uint64]handler{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add handler test
|
||||||
|
rm := s.AddHandler(func(m *gateway.MessageCreateEvent) {
|
||||||
|
results <- m.Content
|
||||||
|
})
|
||||||
|
|
||||||
|
go s.call(&gateway.MessageCreateEvent{
|
||||||
|
Content: "test",
|
||||||
|
})
|
||||||
|
|
||||||
|
if r := <-results; r != "test" {
|
||||||
|
t.Fatal("Returned results is wrong:", r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove handler test
|
||||||
|
rm()
|
||||||
|
|
||||||
|
go s.call(&gateway.MessageCreateEvent{
|
||||||
|
Content: "test",
|
||||||
|
})
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-results:
|
||||||
|
t.Fatal("Unexpected results")
|
||||||
|
case <-time.After(time.Millisecond):
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalid type test
|
||||||
|
rm, err := s.AddHandlerCheck("this should panic")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("No errors found")
|
||||||
|
}
|
||||||
|
defer rm()
|
||||||
|
|
||||||
|
if !strings.Contains(err.Error(), "given interface is not a function") {
|
||||||
|
t.Fatal("Unexpected error:", err)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in a new issue