mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-01-02 18:26:41 +00:00
Added session, started state
This commit is contained in:
parent
4a529dd2ec
commit
d627690835
141
internal/handler/handler.go
Normal file
141
internal/handler/handler.go
Normal file
|
@ -0,0 +1,141 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
// Synchronous controls whether to spawn each event handler in its own
|
||||
// goroutine. Default false (meaning goroutines are spawned).
|
||||
Synchronous bool
|
||||
|
||||
handlers map[uint64]handler
|
||||
hserial uint64
|
||||
hmutex sync.Mutex
|
||||
}
|
||||
|
||||
func New() *Handler {
|
||||
return &Handler{
|
||||
handlers: map[uint64]handler{},
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) Call(ev interface{}) {
|
||||
var evV = reflect.ValueOf(ev)
|
||||
var evT = evV.Type()
|
||||
|
||||
h.hmutex.Lock()
|
||||
defer h.hmutex.Unlock()
|
||||
|
||||
for _, handler := range h.handlers {
|
||||
if handler.not(evT) {
|
||||
continue
|
||||
}
|
||||
|
||||
if h.Synchronous {
|
||||
handler.call(evV)
|
||||
} else {
|
||||
go handler.call(evV)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) AddHandler(handler interface{}) (rm func()) {
|
||||
rm, err := h.addHandler(handler)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return rm
|
||||
}
|
||||
|
||||
// AddHandlerCheck adds the handler, but safe-guards reflect panics with a
|
||||
// recoverer, returning the error.
|
||||
func (h *Handler) AddHandlerCheck(handler interface{}) (rm func(), err error) {
|
||||
// Reflect would actually panic if anything goes wrong, so this is just in
|
||||
// case.
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
if recErr, ok := rec.(error); ok {
|
||||
err = recErr
|
||||
} else {
|
||||
err = fmt.Errorf("%v", rec)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return h.addHandler(handler)
|
||||
}
|
||||
|
||||
func (h *Handler) addHandler(handler interface{}) (rm func(), err error) {
|
||||
// Reflect the handler
|
||||
r, err := reflectFn(handler)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Handler reflect failed")
|
||||
}
|
||||
|
||||
h.hmutex.Lock()
|
||||
defer h.hmutex.Unlock()
|
||||
|
||||
// Get the current counter value and increment the counter
|
||||
serial := h.hserial
|
||||
h.hserial++
|
||||
|
||||
// Use the serial for the map
|
||||
h.handlers[serial] = *r
|
||||
|
||||
return func() {
|
||||
h.hmutex.Lock()
|
||||
defer h.hmutex.Unlock()
|
||||
|
||||
delete(h.handlers, serial)
|
||||
}, nil
|
||||
}
|
||||
|
||||
type handler struct {
|
||||
event reflect.Type
|
||||
callback reflect.Value
|
||||
isIface bool
|
||||
}
|
||||
|
||||
func reflectFn(function interface{}) (*handler, error) {
|
||||
fnV := reflect.ValueOf(function)
|
||||
fnT := fnV.Type()
|
||||
|
||||
if fnT.Kind() != reflect.Func {
|
||||
return nil, errors.New("given interface is not a function")
|
||||
}
|
||||
|
||||
if fnT.NumIn() != 1 {
|
||||
return nil, errors.New("function can only accept 1 event as argument")
|
||||
}
|
||||
|
||||
argT := fnT.In(0)
|
||||
kind := argT.Kind()
|
||||
|
||||
// Accept either pointer type or interface{} type
|
||||
if kind != reflect.Ptr && kind != reflect.Interface {
|
||||
return nil, errors.New("first argument is not pointer")
|
||||
}
|
||||
|
||||
return &handler{
|
||||
event: argT,
|
||||
callback: fnV,
|
||||
isIface: kind == reflect.Interface,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h handler) not(event reflect.Type) bool {
|
||||
if h.isIface {
|
||||
return !event.Implements(h.event)
|
||||
}
|
||||
|
||||
return h.event != event
|
||||
}
|
||||
|
||||
func (h handler) call(event reflect.Value) {
|
||||
h.callback.Call([]reflect.Value{event})
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
package session
|
||||
package handler
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
@ -7,7 +7,7 @@ import (
|
|||
"github.com/diamondburned/arikawa/gateway"
|
||||
)
|
||||
|
||||
func TestReflect(t *testing.T) {
|
||||
func TestHandler(t *testing.T) {
|
||||
var results = make(chan string)
|
||||
|
||||
h, err := reflectFn(func(m *gateway.MessageCreateEvent) {
|
||||
|
@ -36,6 +36,42 @@ func TestReflect(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestHandlerInterface(t *testing.T) {
|
||||
var results = make(chan interface{})
|
||||
|
||||
h, err := reflectFn(func(m interface{}) {
|
||||
results <- m
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
const result = "Hime Arikawa"
|
||||
var msg = &gateway.MessageCreateEvent{
|
||||
Content: result,
|
||||
}
|
||||
|
||||
var msgV = reflect.ValueOf(msg)
|
||||
var msgT = msgV.Type()
|
||||
|
||||
if h.not(msgT) {
|
||||
t.Fatal("Event type mismatch")
|
||||
}
|
||||
|
||||
go h.call(msgV)
|
||||
recv := <-results
|
||||
|
||||
if msg, ok := recv.(*gateway.MessageCreateEvent); ok {
|
||||
if msg.Content == result {
|
||||
return
|
||||
}
|
||||
|
||||
t.Fatal("Content mismatch:", msg.Content)
|
||||
}
|
||||
|
||||
t.Fatal("Assertion failed:", recv)
|
||||
}
|
||||
|
||||
func BenchmarkReflect(b *testing.B) {
|
||||
h, err := reflectFn(func(m *gateway.MessageCreateEvent) {})
|
||||
if err != nil {
|
|
@ -1,43 +0,0 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type handler struct {
|
||||
event reflect.Type
|
||||
callback reflect.Value
|
||||
}
|
||||
|
||||
func reflectFn(function interface{}) (*handler, error) {
|
||||
fnV := reflect.ValueOf(function)
|
||||
fnT := fnV.Type()
|
||||
|
||||
if fnT.Kind() != reflect.Func {
|
||||
return nil, errors.New("given interface is not a function")
|
||||
}
|
||||
|
||||
if fnT.NumIn() != 1 {
|
||||
return nil, errors.New("function can only accept 1 event as argument")
|
||||
}
|
||||
|
||||
argT := fnT.In(0)
|
||||
|
||||
if argT.Kind() != reflect.Ptr {
|
||||
return nil, errors.New("first argument is not pointer")
|
||||
}
|
||||
|
||||
return &handler{
|
||||
event: argT,
|
||||
callback: fnV,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h handler) not(event reflect.Type) bool {
|
||||
return h.event != event
|
||||
}
|
||||
|
||||
func (h handler) call(event reflect.Value) {
|
||||
h.callback.Call([]reflect.Value{event})
|
||||
}
|
|
@ -1,53 +1,33 @@
|
|||
// Package session abstracts around the REST API and the Gateway, managing both
|
||||
// at once. It offers a handler interface similar to that in discordgo for
|
||||
// Gateway events.
|
||||
package session
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"github.com/diamondburned/arikawa/api"
|
||||
"github.com/diamondburned/arikawa/gateway"
|
||||
"github.com/diamondburned/arikawa/internal/handler"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
/*
|
||||
TODO:
|
||||
|
||||
and Session's supposed to handle callbacks too kec
|
||||
|
||||
might move all these to Gateway, dunno
|
||||
|
||||
could have a lock on Listen()
|
||||
|
||||
I can actually see people using gateway channels to handle things
|
||||
themselves without any callback abstractions, so this is probably the way to go
|
||||
|
||||
welp shit
|
||||
|
||||
rewrite imminent
|
||||
*/
|
||||
|
||||
type Session struct {
|
||||
*api.Client
|
||||
gateway *gateway.Gateway
|
||||
|
||||
// ErrorLog logs errors, including Gateway errors.
|
||||
ErrorLog func(err error) // default to log.Println
|
||||
|
||||
// Synchronous controls whether to spawn each event handler in its own
|
||||
// goroutine. Default false (meaning goroutines are spawned).
|
||||
Synchronous bool
|
||||
|
||||
// handlers stuff
|
||||
handlers map[uint64]handler
|
||||
hserial uint64
|
||||
hmutex sync.Mutex
|
||||
hstop chan<- struct{}
|
||||
*handler.Handler
|
||||
hstop chan struct{}
|
||||
}
|
||||
|
||||
func New(token string) (*Session, error) {
|
||||
// Initialize the session and the API interface
|
||||
s := &Session{}
|
||||
s.Handler = handler.New()
|
||||
s.Client = api.NewClient(token)
|
||||
|
||||
// Default logger
|
||||
|
@ -61,19 +41,28 @@ func New(token string) (*Session, error) {
|
|||
return nil, errors.Wrap(err, "Failed to connect to Gateway")
|
||||
}
|
||||
s.gateway = g
|
||||
s.gateway.ErrorLog = func(err error) {
|
||||
s.ErrorLog(err)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func NewWithGateway(gw *gateway.Gateway) *Session {
|
||||
return &Session{
|
||||
s := &Session{
|
||||
// Nab off gateway's token
|
||||
Client: api.NewClient(gw.Identifier.Token),
|
||||
ErrorLog: func(err error) {
|
||||
log.Println("Arikawa/session error:", err)
|
||||
},
|
||||
handlers: map[uint64]handler{},
|
||||
Handler: handler.New(),
|
||||
}
|
||||
|
||||
gw.ErrorLog = func(err error) {
|
||||
s.ErrorLog(err)
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Session) Open() error {
|
||||
|
@ -88,84 +77,13 @@ func (s *Session) Open() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) AddHandler(handler interface{}) (rm func()) {
|
||||
rm, err := s.addHandler(handler)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return rm
|
||||
}
|
||||
|
||||
// AddHandlerCheck adds the handler, but safe-guards reflect panics with a
|
||||
// recoverer, returning the error.
|
||||
func (s *Session) AddHandlerCheck(handler interface{}) (rm func(), err error) {
|
||||
// Reflect would actually panic if anything goes wrong, so this is just in
|
||||
// case.
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
if recErr, ok := rec.(error); ok {
|
||||
err = recErr
|
||||
} else {
|
||||
err = fmt.Errorf("%v", rec)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return s.addHandler(handler)
|
||||
}
|
||||
|
||||
func (s *Session) addHandler(handler interface{}) (rm func(), err error) {
|
||||
// Reflect the handler
|
||||
h, err := reflectFn(handler)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Handler reflect failed")
|
||||
}
|
||||
|
||||
s.hmutex.Lock()
|
||||
defer s.hmutex.Unlock()
|
||||
|
||||
// Get the current counter value and increment the counter
|
||||
serial := s.hserial
|
||||
s.hserial++
|
||||
|
||||
// Use the serial for the map
|
||||
s.handlers[serial] = *h
|
||||
|
||||
return func() {
|
||||
s.hmutex.Lock()
|
||||
defer s.hmutex.Unlock()
|
||||
|
||||
delete(s.handlers, serial)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Session) startHandler(stop <-chan struct{}) {
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
return
|
||||
case ev := <-s.gateway.Events:
|
||||
s.call(ev)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) call(ev interface{}) {
|
||||
var evV = reflect.ValueOf(ev)
|
||||
var evT = evV.Type()
|
||||
|
||||
s.hmutex.Lock()
|
||||
defer s.hmutex.Unlock()
|
||||
|
||||
for _, handler := range s.handlers {
|
||||
if handler.not(evT) {
|
||||
continue
|
||||
}
|
||||
|
||||
if s.Synchronous {
|
||||
handler.call(evV)
|
||||
} else {
|
||||
go handler.call(evV)
|
||||
s.Handler.Call(ev)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
65
state/state.go
Normal file
65
state/state.go
Normal file
|
@ -0,0 +1,65 @@
|
|||
package state
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/diamondburned/arikawa/discord"
|
||||
"github.com/diamondburned/arikawa/gateway"
|
||||
"github.com/diamondburned/arikawa/internal/handler"
|
||||
"github.com/diamondburned/arikawa/session"
|
||||
)
|
||||
|
||||
type State struct {
|
||||
*session.Session
|
||||
|
||||
// PreHandler is the manual hook that is executed before the State handler
|
||||
// is. This should only be used for low-level operations.
|
||||
// It's recommended to set Synchronous to true if you mutate the events.
|
||||
PreHandler *handler.Handler
|
||||
|
||||
guilds []discord.Guild
|
||||
channels []discord.Channel
|
||||
privates []discord.Channel
|
||||
messages map[discord.Snowflake][]discord.Message
|
||||
|
||||
mut sync.Mutex
|
||||
|
||||
unhooker func()
|
||||
}
|
||||
|
||||
func NewFromSession(s *session.Session) (*State, error) {
|
||||
state := &State{
|
||||
Session: s,
|
||||
messages: map[discord.Snowflake][]discord.Message{},
|
||||
}
|
||||
|
||||
return state, state.hookSession()
|
||||
}
|
||||
|
||||
// Unhook removes all state handlers from the session handlers.
|
||||
func (s *State) Unhook() {
|
||||
s.unhooker()
|
||||
}
|
||||
|
||||
// Reset resets the entire state.
|
||||
func (s *State) Reset() {
|
||||
s.mut.Lock()
|
||||
defer s.mut.Unlock()
|
||||
|
||||
panic("IMPLEMENT ME")
|
||||
}
|
||||
|
||||
func (s *State) hookSession() error {
|
||||
s.unhooker = s.Session.AddHandler(func(iface interface{}) {
|
||||
if s.PreHandler != nil {
|
||||
s.PreHandler.Call(iface)
|
||||
}
|
||||
|
||||
switch ev := iface.(type) {
|
||||
case *gateway.MessageCreateEvent:
|
||||
_ = ev
|
||||
panic("IMPLEMENT ME")
|
||||
}
|
||||
})
|
||||
return nil
|
||||
}
|
Loading…
Reference in a new issue