mirror of
https://github.com/diamondburned/arikawa.git
synced 2025-07-29 00:42:16 +00:00
Compare commits
9 commits
75df94d9f4
...
9b4b707070
Author | SHA1 | Date | |
---|---|---|---|
|
9b4b707070 | ||
|
9899f6073b | ||
|
ef48d686cd | ||
|
b7b8118d0b | ||
|
b8e4b2cf56 | ||
|
c00d31ce30 | ||
|
fd16db1385 | ||
|
33e7abd4db | ||
|
160a4e6606 |
|
@ -2,8 +2,24 @@ package bot
|
|||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/diamondburned/arikawa/v2/gateway"
|
||||
)
|
||||
|
||||
// eventIntents maps event pointer types to intents.
|
||||
var eventIntents = map[reflect.Type]gateway.Intents{}
|
||||
|
||||
func init() {
|
||||
for event, intent := range gateway.EventIntents {
|
||||
fn, ok := gateway.EventCreator[event]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
eventIntents[reflect.TypeOf(fn())] = intent
|
||||
}
|
||||
}
|
||||
|
||||
type command struct {
|
||||
value reflect.Value // Func
|
||||
event reflect.Type
|
||||
|
@ -26,6 +42,15 @@ func (c *command) call(arg0 interface{}, argv ...reflect.Value) (interface{}, er
|
|||
return callWith(c.value, arg0, argv...)
|
||||
}
|
||||
|
||||
// intents returns the command's intents from the event.
|
||||
func (c *command) intents() gateway.Intents {
|
||||
intents, ok := eventIntents[c.event]
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
return intents
|
||||
}
|
||||
|
||||
func callWith(caller reflect.Value, arg0 interface{}, argv ...reflect.Value) (interface{}, error) {
|
||||
var callargs = make([]reflect.Value, 0, 1+len(argv))
|
||||
|
||||
|
|
18
bot/ctx.go
18
bot/ctx.go
|
@ -165,6 +165,8 @@ func Start(
|
|||
}
|
||||
}
|
||||
|
||||
s.Gateway.AddIntents(c.DeriveIntents())
|
||||
|
||||
cancel := c.Start()
|
||||
|
||||
if err := s.Open(); err != nil {
|
||||
|
@ -229,10 +231,10 @@ func New(s *state.State, cmd interface{}) (*Context, error) {
|
|||
return ctx, nil
|
||||
}
|
||||
|
||||
// AddIntent adds the given Gateway Intent into the Gateway. This is a
|
||||
// AddIntents adds the given Gateway Intent into the Gateway. This is a
|
||||
// convenient function that calls Gateway's AddIntent.
|
||||
func (ctx *Context) AddIntent(i gateway.Intents) {
|
||||
ctx.Gateway.AddIntent(i)
|
||||
func (ctx *Context) AddIntents(i gateway.Intents) {
|
||||
ctx.Gateway.AddIntents(i)
|
||||
}
|
||||
|
||||
// Subcommands returns the slice of subcommands. To add subcommands, use
|
||||
|
@ -444,3 +446,13 @@ func IndentLines(input string) string {
|
|||
}
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
// DeriveIntents derives all possible gateway intents from this context and all
|
||||
// its subcommands' method handlers and middlewares.
|
||||
func (ctx *Context) DeriveIntents() gateway.Intents {
|
||||
var intents = ctx.Subcommand.DeriveIntents()
|
||||
for _, subcmd := range ctx.subcommands {
|
||||
intents |= subcmd.DeriveIntents()
|
||||
}
|
||||
return intents
|
||||
}
|
||||
|
|
|
@ -149,6 +149,21 @@ func TestContext(t *testing.T) {
|
|||
}
|
||||
})
|
||||
|
||||
t.Run("derive intents", func(t *testing.T) {
|
||||
intents := ctx.DeriveIntents()
|
||||
|
||||
assertIntents := func(target gateway.Intents, name string) {
|
||||
if !intents.Has(target) {
|
||||
t.Error("Derived intents do not have", name)
|
||||
}
|
||||
}
|
||||
|
||||
assertIntents(gateway.IntentGuildMessages, "guild messages")
|
||||
assertIntents(gateway.IntentDirectMessages, "direct messages")
|
||||
assertIntents(gateway.IntentGuildMessageTyping, "guild typing")
|
||||
assertIntents(gateway.IntentDirectMessageTyping, "direct message typing")
|
||||
})
|
||||
|
||||
t.Run("typing event", func(t *testing.T) {
|
||||
typing := &gateway.TypingStartEvent{}
|
||||
|
||||
|
|
|
@ -143,6 +143,10 @@ func (sub *Subcommand) NeedsName() {
|
|||
sub.Command = lowerFirstLetter(sub.StructName)
|
||||
}
|
||||
|
||||
func lowerFirstLetter(name string) string {
|
||||
return strings.ToLower(string(name[0])) + name[1:]
|
||||
}
|
||||
|
||||
// FindCommand finds the MethodContext. It panics if methodName is not found.
|
||||
func (sub *Subcommand) FindCommand(methodName string) *MethodContext {
|
||||
for _, c := range sub.Commands {
|
||||
|
@ -413,6 +417,23 @@ func (sub *Subcommand) AddAliases(commandName string, aliases ...string) {
|
|||
command.Aliases = append(command.Aliases, aliases...)
|
||||
}
|
||||
|
||||
func lowerFirstLetter(name string) string {
|
||||
return strings.ToLower(string(name[0])) + name[1:]
|
||||
// DeriveIntents derives all possible gateway intents from the method handlers
|
||||
// and middlewares.
|
||||
func (sub *Subcommand) DeriveIntents() gateway.Intents {
|
||||
var intents gateway.Intents
|
||||
|
||||
for _, event := range sub.Events {
|
||||
intents |= event.intents()
|
||||
}
|
||||
for _, command := range sub.Commands {
|
||||
intents |= command.intents()
|
||||
}
|
||||
if sub.plumbed != nil {
|
||||
intents |= sub.plumbed.intents()
|
||||
}
|
||||
for _, middleware := range sub.globalmws {
|
||||
intents |= middleware.intents()
|
||||
}
|
||||
|
||||
return intents
|
||||
}
|
||||
|
|
|
@ -248,7 +248,7 @@ type (
|
|||
MessageID discord.MessageID `json:"message_id"`
|
||||
GuildID discord.GuildID `json:"guild_id,omitempty"`
|
||||
}
|
||||
MessageReactionRemoveEmoji struct {
|
||||
MessageReactionRemoveEmojiEvent struct {
|
||||
ChannelID discord.ChannelID `json:"channel_id"`
|
||||
MessageID discord.MessageID `json:"message_id"`
|
||||
Emoji discord.Emoji `json:"emoji"`
|
||||
|
|
|
@ -3,6 +3,7 @@ package gateway
|
|||
// Event is any event struct. They have an "Event" suffixed to them.
|
||||
type Event = interface{}
|
||||
|
||||
// EventCreator maps an event type string to a constructor.
|
||||
var EventCreator = map[string]func() Event{
|
||||
"HELLO": func() Event { return new(HelloEvent) },
|
||||
"READY": func() Event { return new(ReadyEvent) },
|
||||
|
@ -44,9 +45,10 @@ var EventCreator = map[string]func() Event{
|
|||
"MESSAGE_DELETE": func() Event { return new(MessageDeleteEvent) },
|
||||
"MESSAGE_DELETE_BULK": func() Event { return new(MessageDeleteBulkEvent) },
|
||||
|
||||
"MESSAGE_REACTION_ADD": func() Event { return new(MessageReactionAddEvent) },
|
||||
"MESSAGE_REACTION_REMOVE": func() Event { return new(MessageReactionRemoveEvent) },
|
||||
"MESSAGE_REACTION_REMOVE_ALL": func() Event { return new(MessageReactionRemoveAllEvent) },
|
||||
"MESSAGE_REACTION_ADD": func() Event { return new(MessageReactionAddEvent) },
|
||||
"MESSAGE_REACTION_REMOVE": func() Event { return new(MessageReactionRemoveEvent) },
|
||||
"MESSAGE_REACTION_REMOVE_ALL": func() Event { return new(MessageReactionRemoveAllEvent) },
|
||||
"MESSAGE_REACTION_REMOVE_EMOJI": func() Event { return new(MessageReactionRemoveEmojiEvent) },
|
||||
|
||||
"MESSAGE_ACK": func() Event { return new(MessageAckEvent) },
|
||||
|
||||
|
|
|
@ -114,7 +114,7 @@ func NewGatewayWithIntents(token string, intents ...Intents) (*Gateway, error) {
|
|||
}
|
||||
|
||||
for _, intent := range intents {
|
||||
g.AddIntent(intent)
|
||||
g.AddIntents(intent)
|
||||
}
|
||||
|
||||
return g, nil
|
||||
|
@ -154,9 +154,9 @@ func NewCustomGateway(gatewayURL, token string) *Gateway {
|
|||
}
|
||||
}
|
||||
|
||||
// AddIntent adds a Gateway Intent before connecting to the Gateway. As
|
||||
// such, this function will only work before Open() is called.
|
||||
func (g *Gateway) AddIntent(i Intents) {
|
||||
// AddIntents adds a Gateway Intent before connecting to the Gateway. As such,
|
||||
// this function will only work before Open() is called.
|
||||
func (g *Gateway) AddIntents(i Intents) {
|
||||
g.Identifier.Intents |= i
|
||||
}
|
||||
|
||||
|
|
|
@ -55,28 +55,6 @@ func (i *IdentifyData) SetShard(id, num int) {
|
|||
i.Shard[0], i.Shard[1] = id, num
|
||||
}
|
||||
|
||||
// Intents for the new Discord API feature, documented at
|
||||
// https://discordapp.com/developers/docs/topics/gateway#gateway-intents.
|
||||
type Intents uint32
|
||||
|
||||
const (
|
||||
IntentGuilds Intents = 1 << iota
|
||||
IntentGuildMembers
|
||||
IntentGuildBans
|
||||
IntentGuildEmojis
|
||||
IntentGuildIntegrations
|
||||
IntentGuildWebhooks
|
||||
IntentGuildInvites
|
||||
IntentGuildVoiceStates
|
||||
IntentGuildPresences
|
||||
IntentGuildMessages
|
||||
IntentGuildMessageReactions
|
||||
IntentGuildMessageTyping
|
||||
IntentDirectMessages
|
||||
IntentDirectMessageReactions
|
||||
IntentDirectMessageTyping
|
||||
)
|
||||
|
||||
type Identifier struct {
|
||||
IdentifyData
|
||||
|
||||
|
|
44
gateway/intents.go
Normal file
44
gateway/intents.go
Normal file
|
@ -0,0 +1,44 @@
|
|||
package gateway
|
||||
|
||||
import "github.com/diamondburned/arikawa/v2/discord"
|
||||
|
||||
// Intents for the new Discord API feature, documented at
|
||||
// https://discordapp.com/developers/docs/topics/gateway#gateway-intents.
|
||||
type Intents uint32
|
||||
|
||||
const (
|
||||
IntentGuilds Intents = 1 << iota
|
||||
IntentGuildMembers
|
||||
IntentGuildBans
|
||||
IntentGuildEmojis
|
||||
IntentGuildIntegrations
|
||||
IntentGuildWebhooks
|
||||
IntentGuildInvites
|
||||
IntentGuildVoiceStates
|
||||
IntentGuildPresences
|
||||
IntentGuildMessages
|
||||
IntentGuildMessageReactions
|
||||
IntentGuildMessageTyping
|
||||
IntentDirectMessages
|
||||
IntentDirectMessageReactions
|
||||
IntentDirectMessageTyping
|
||||
)
|
||||
|
||||
// PrivilegedIntents contains a list of privileged intents that Discord requires
|
||||
// bots to have these intents explicitly enabled in the Developer Portal.
|
||||
var PrivilegedIntents = []Intents{
|
||||
IntentGuildPresences,
|
||||
IntentGuildMembers,
|
||||
}
|
||||
|
||||
// Has returns true if i has the given intents.
|
||||
func (i Intents) Has(intents Intents) bool {
|
||||
return discord.HasFlag(uint64(i), uint64(intents))
|
||||
}
|
||||
|
||||
// IsPrivileged returns true for each of the boolean that indicates the type of
|
||||
// the privilege.
|
||||
func (i Intents) IsPrivileged() (presences, member bool) {
|
||||
// Keep this in sync with PrivilegedIntents.
|
||||
return i.Has(IntentGuildPresences), i.Has(IntentGuildMembers)
|
||||
}
|
47
gateway/intents_map.go
Normal file
47
gateway/intents_map.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
package gateway
|
||||
|
||||
// EventIntents maps event types to intents.
|
||||
var EventIntents = map[string]Intents{
|
||||
"GUILD_CREATE": IntentGuilds,
|
||||
"GUILD_UPDATE": IntentGuilds,
|
||||
"GUILD_DELETE": IntentGuilds,
|
||||
"GUILD_ROLE_CREATE": IntentGuilds,
|
||||
"GUILD_ROLE_UPDATE": IntentGuilds,
|
||||
"GUILD_ROLE_DELETE": IntentGuilds,
|
||||
"CHANNEL_CREATE": IntentGuilds,
|
||||
"CHANNEL_UPDATE": IntentGuilds,
|
||||
"CHANNEL_DELETE": IntentGuilds,
|
||||
"CHANNEL_PINS_UPDATE": IntentGuilds | IntentDirectMessages,
|
||||
|
||||
"GUILD_MEMBER_ADD": IntentGuildMembers,
|
||||
"GUILD_MEMBER_REMOVE": IntentGuildMembers,
|
||||
"GUILD_MEMBER_UPDATE": IntentGuildMembers,
|
||||
|
||||
"GUILD_BAN_ADD": IntentGuildBans,
|
||||
"GUILD_BAN_REMOVE": IntentGuildBans,
|
||||
|
||||
"GUILD_EMOJIS_UPDATE": IntentGuildEmojis,
|
||||
|
||||
"GUILD_INTEGRATIONS_UPDATE": IntentGuildIntegrations,
|
||||
|
||||
"WEBHOOKS_UPDATE": IntentGuildWebhooks,
|
||||
|
||||
"INVITE_CREATE": IntentGuildInvites,
|
||||
"INVITE_DELETE": IntentGuildInvites,
|
||||
|
||||
"VOICE_STATE_UPDATE": IntentGuildVoiceStates,
|
||||
|
||||
"PRESENCE_UPDATE": IntentGuildPresences,
|
||||
|
||||
"MESSAGE_CREATE": IntentGuildMessages | IntentDirectMessages,
|
||||
"MESSAGE_UPDATE": IntentGuildMessages | IntentDirectMessages,
|
||||
"MESSAGE_DELETE": IntentGuildMessages | IntentDirectMessages,
|
||||
"MESSAGE_DELETE_BULK": IntentGuildMessages,
|
||||
|
||||
"MESSAGE_REACTION_ADD": IntentGuildMessageReactions | IntentDirectMessageReactions,
|
||||
"MESSAGE_REACTION_REMOVE": IntentGuildMessageReactions | IntentDirectMessageReactions,
|
||||
"MESSAGE_REACTION_REMOVE_ALL": IntentGuildMessageReactions | IntentDirectMessageReactions,
|
||||
"MESSAGE_REACTION_REMOVE_EMOJI": IntentGuildMessageReactions | IntentDirectMessageReactions,
|
||||
|
||||
"TYPING_START": IntentGuildMessageTyping | IntentDirectMessageTyping,
|
||||
}
|
|
@ -240,7 +240,7 @@ func (s *State) onEvent(iface interface{}) {
|
|||
return true
|
||||
})
|
||||
|
||||
case *gateway.MessageReactionRemoveEmoji:
|
||||
case *gateway.MessageReactionRemoveEmojiEvent:
|
||||
s.editMessage(ev.ChannelID, ev.MessageID, func(m *discord.Message) bool {
|
||||
var i = findReaction(m.Reactions, ev.Emoji)
|
||||
if i < 0 {
|
||||
|
|
|
@ -25,21 +25,19 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Handler is a container for command handlers. A zero-value instance is a valid
|
||||
// instance.
|
||||
type Handler struct {
|
||||
// Synchronous controls whether to spawn each event handler in its own
|
||||
// goroutine. Default false (meaning goroutines are spawned).
|
||||
Synchronous bool
|
||||
|
||||
handlers map[uint64]handler
|
||||
horders []uint64
|
||||
hserial uint64
|
||||
hmutex sync.RWMutex
|
||||
mutex sync.RWMutex
|
||||
slab slab
|
||||
}
|
||||
|
||||
func New() *Handler {
|
||||
return &Handler{
|
||||
handlers: map[uint64]handler{},
|
||||
}
|
||||
return &Handler{}
|
||||
}
|
||||
|
||||
// Call calls all handlers with the given event. This is an internal method; use
|
||||
|
@ -48,52 +46,18 @@ func (h *Handler) Call(ev interface{}) {
|
|||
var evV = reflect.ValueOf(ev)
|
||||
var evT = evV.Type()
|
||||
|
||||
h.hmutex.RLock()
|
||||
defer h.hmutex.RUnlock()
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
for _, order := range h.horders {
|
||||
handler, ok := h.handlers[order]
|
||||
if !ok {
|
||||
// This shouldn't ever happen, but we're adding this just in case.
|
||||
continue
|
||||
}
|
||||
|
||||
if handler.not(evT) {
|
||||
for _, entry := range h.slab.Entries {
|
||||
if entry.isInvalid() || entry.not(evT) {
|
||||
continue
|
||||
}
|
||||
|
||||
if h.Synchronous {
|
||||
handler.call(evV)
|
||||
entry.call(evV)
|
||||
} else {
|
||||
go handler.call(evV)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CallDirect is the same as Call, but only calls those event handlers that
|
||||
// listen for this specific event, i.e. that aren't interface handlers.
|
||||
func (h *Handler) CallDirect(ev interface{}) {
|
||||
var evV = reflect.ValueOf(ev)
|
||||
var evT = evV.Type()
|
||||
|
||||
h.hmutex.RLock()
|
||||
defer h.hmutex.RUnlock()
|
||||
|
||||
for _, order := range h.horders {
|
||||
handler, ok := h.handlers[order]
|
||||
if !ok {
|
||||
// This shouldn't ever happen, but we're adding this just in case.
|
||||
continue
|
||||
}
|
||||
|
||||
if evT != handler.event {
|
||||
continue
|
||||
}
|
||||
|
||||
if h.Synchronous {
|
||||
handler.call(evV)
|
||||
} else {
|
||||
go handler.call(evV)
|
||||
go entry.call(evV)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -213,47 +177,16 @@ func (h *Handler) addHandler(fn interface{}) (rm func(), err error) {
|
|||
return nil, errors.Wrap(err, "handler reflect failed")
|
||||
}
|
||||
|
||||
h.hmutex.Lock()
|
||||
defer h.hmutex.Unlock()
|
||||
|
||||
// Get the current counter value and increment the counter:
|
||||
serial := h.hserial
|
||||
h.hserial++
|
||||
|
||||
// Create a map if there's none:
|
||||
if h.handlers == nil {
|
||||
h.handlers = map[uint64]handler{}
|
||||
}
|
||||
|
||||
// Use the serial for the map:
|
||||
h.handlers[serial] = *r
|
||||
|
||||
// Append the serial into the list of keys:
|
||||
h.horders = append(h.horders, serial)
|
||||
h.mutex.Lock()
|
||||
id := h.slab.Put(r)
|
||||
h.mutex.Unlock()
|
||||
|
||||
return func() {
|
||||
h.hmutex.Lock()
|
||||
defer h.hmutex.Unlock()
|
||||
h.mutex.Lock()
|
||||
popped := h.slab.Pop(id)
|
||||
h.mutex.Unlock()
|
||||
|
||||
// Take and delete the handler from the map, but return if we can't find
|
||||
// it.
|
||||
hd, ok := h.handlers[serial]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
delete(h.handlers, serial)
|
||||
|
||||
// Delete the key from the orders slice:
|
||||
for i, order := range h.horders {
|
||||
if order == serial {
|
||||
h.horders = append(h.horders[:i], h.horders[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up the handler.
|
||||
hd.cleanup()
|
||||
popped.cleanup()
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -267,7 +200,7 @@ type handler struct {
|
|||
// newHandler reflects either a channel or a function into a handler. A function
|
||||
// must only have a single argument being the event and no return, and a channel
|
||||
// must have the event type as the underlying type.
|
||||
func newHandler(unknown interface{}) (*handler, error) {
|
||||
func newHandler(unknown interface{}) (handler, error) {
|
||||
fnV := reflect.ValueOf(unknown)
|
||||
fnT := fnV.Type()
|
||||
|
||||
|
@ -279,11 +212,11 @@ func newHandler(unknown interface{}) (*handler, error) {
|
|||
switch fnT.Kind() {
|
||||
case reflect.Func:
|
||||
if fnT.NumIn() != 1 {
|
||||
return nil, errors.New("function can only accept 1 event as argument")
|
||||
return handler, errors.New("function can only accept 1 event as argument")
|
||||
}
|
||||
|
||||
if fnT.NumOut() > 0 {
|
||||
return nil, errors.New("function can't accept returns")
|
||||
return handler, errors.New("function can't accept returns")
|
||||
}
|
||||
|
||||
handler.event = fnT.In(0)
|
||||
|
@ -293,19 +226,19 @@ func newHandler(unknown interface{}) (*handler, error) {
|
|||
handler.chanclose = reflect.ValueOf(make(chan struct{}))
|
||||
|
||||
default:
|
||||
return nil, errors.New("given interface is not a function or channel")
|
||||
return handler, errors.New("given interface is not a function or channel")
|
||||
}
|
||||
|
||||
var kind = handler.event.Kind()
|
||||
|
||||
// Accept either pointer type or interface{} type
|
||||
if kind != reflect.Ptr && kind != reflect.Interface {
|
||||
return nil, errors.New("first argument is not pointer")
|
||||
return handler, errors.New("first argument is not pointer")
|
||||
}
|
||||
|
||||
handler.isIface = kind == reflect.Interface
|
||||
|
||||
return &handler, nil
|
||||
return handler, nil
|
||||
}
|
||||
|
||||
func (h handler) not(event reflect.Type) bool {
|
||||
|
@ -316,7 +249,7 @@ func (h handler) not(event reflect.Type) bool {
|
|||
return h.event != event
|
||||
}
|
||||
|
||||
func (h *handler) call(event reflect.Value) {
|
||||
func (h handler) call(event reflect.Value) {
|
||||
if h.chanclose.IsValid() {
|
||||
reflect.Select([]reflect.SelectCase{
|
||||
{Dir: reflect.SelectSend, Chan: h.callback, Send: event},
|
||||
|
@ -327,7 +260,7 @@ func (h *handler) call(event reflect.Value) {
|
|||
}
|
||||
}
|
||||
|
||||
func (h *handler) cleanup() {
|
||||
func (h handler) cleanup() {
|
||||
if h.chanclose.IsValid() {
|
||||
// Closing this channel will force all ongoing selects to return
|
||||
// immediately.
|
||||
|
|
|
@ -20,9 +20,7 @@ func newMessage(content string) *gateway.MessageCreateEvent {
|
|||
func TestCall(t *testing.T) {
|
||||
var results = make(chan string)
|
||||
|
||||
h := &Handler{
|
||||
handlers: map[uint64]handler{},
|
||||
}
|
||||
h := &Handler{}
|
||||
|
||||
// Add handler test
|
||||
rm := h.AddHandler(func(m *gateway.MessageCreateEvent) {
|
||||
|
|
44
utils/handler/slab.go
Normal file
44
utils/handler/slab.go
Normal file
|
@ -0,0 +1,44 @@
|
|||
package handler
|
||||
|
||||
type slabEntry struct {
|
||||
handler
|
||||
index int
|
||||
}
|
||||
|
||||
func (entry slabEntry) isInvalid() bool {
|
||||
return entry.index != -1
|
||||
}
|
||||
|
||||
// slab is an implementation of the internal handler free list.
|
||||
type slab struct {
|
||||
Entries []slabEntry
|
||||
free int
|
||||
}
|
||||
|
||||
func (s *slab) Put(entry handler) int {
|
||||
if s.free == len(s.Entries) {
|
||||
index := len(s.Entries)
|
||||
s.Entries = append(s.Entries, slabEntry{entry, -1})
|
||||
s.free++
|
||||
return index
|
||||
}
|
||||
|
||||
next := s.Entries[s.free].index
|
||||
s.Entries[s.free] = slabEntry{entry, -1}
|
||||
|
||||
i := s.free
|
||||
s.free = next
|
||||
|
||||
return i
|
||||
}
|
||||
|
||||
func (s *slab) Get(i int) handler {
|
||||
return s.Entries[i].handler
|
||||
}
|
||||
|
||||
func (s *slab) Pop(i int) handler {
|
||||
popped := s.Entries[i].handler
|
||||
s.Entries[i] = slabEntry{handler{}, s.free}
|
||||
s.free = i
|
||||
return popped
|
||||
}
|
|
@ -85,24 +85,26 @@ func (c *Conn) Dial(ctx context.Context, addr string) error {
|
|||
// BUG which prevents stream compression.
|
||||
// See https://github.com/golang/go/issues/31514.
|
||||
|
||||
conn, _, err := c.dialer.DialContext(ctx, addr, headers)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to dial WS")
|
||||
}
|
||||
|
||||
events := make(chan Event, WSBuffer)
|
||||
go startReadLoop(conn, events)
|
||||
var err error
|
||||
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
c.Conn = conn
|
||||
c.events = events
|
||||
c.Conn, _, err = c.dialer.DialContext(ctx, addr, headers)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to dial WS")
|
||||
}
|
||||
|
||||
c.events = make(chan Event, WSBuffer)
|
||||
go startReadLoop(c.Conn, c.events)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) Listen() <-chan Event {
|
||||
c.mutex.Lock()
|
||||
defer c.mutex.Unlock()
|
||||
|
||||
return c.events
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue