voice: Allow setting udp.DialFunc

This commit fixes up the SetDialer method to accept a udp.DialFunc
instead of just a regular *net.Dialer, which is more flexible in that
the user can now control the UDP packet frequency properly.
This commit is contained in:
diamondburned 2022-04-03 17:48:15 -07:00
parent 816a53b1e1
commit ae24217e34
No known key found for this signature in database
GPG Key ID: D78C4471CE776659
3 changed files with 38 additions and 16 deletions

View File

@ -2,7 +2,6 @@ package voice
import (
"context"
"net"
"sync"
"time"
@ -66,10 +65,6 @@ var (
_ MainSession = (*state.State)(nil)
)
// UDPDialer is the UDP dialer function type. It's the function signature for
// udp.DialConnection.
type UDPDialer = func(ctx context.Context, addr string, ssrc uint32) (*udp.Connection, error)
// Session is a single voice session that wraps around the voice gateway and UDP
// connection.
type Session struct {
@ -146,7 +141,7 @@ func NewSessionCustom(ses MainSession, userID discord.UserID) *Session {
// SetUDPDialer sets the given dialer to be used for dialing UDP voice
// connections.
func (s *Session) SetUDPDialer(d *net.Dialer) {
func (s *Session) SetUDPDialer(d udp.DialFunc) {
s.udpManager.SetDialer(d)
}

View File

@ -17,9 +17,10 @@ import (
// to decrypt.
var ErrDecryptionFailed = errors.New("decryption failed")
// Dialer is the default dialer that this package uses for all its dialing.
var Dialer = net.Dialer{
Timeout: 10 * time.Second,
// defaultDialer is the default dialer that this package uses for all its
// dialing.
var defaultDialer = net.Dialer{
Timeout: 30 * time.Second,
}
// Connection represents a voice connection. It is not thread-safe.
@ -51,10 +52,31 @@ type Connection struct {
closed sync.Once
}
// DialFunc is the UDP dialer function type. It's the function signature for
// udp.DialConnection.
type DialFunc = func(ctx context.Context, addr string, ssrc uint32) (*Connection, error)
// Assert that this is the same.
var _ DialFunc = DialConnection
// DialFuncWithFrequency creates a new DialFunc with the given frame duration
// and time increment. See Connection's ResetFrequency method for more
// information.
func DialFuncWithFrequency(frameDuration time.Duration, timeIncr uint32) DialFunc {
return func(ctx context.Context, addr string, ssrc uint32) (*Connection, error) {
u, err := DialConnection(ctx, addr, ssrc)
if err != nil {
return nil, err
}
u.ResetFrequency(frameDuration, timeIncr)
return u, nil
}
}
// DialConnection dials the UDP connection using the given address and SSRC
// number.
func DialConnection(ctx context.Context, addr string, ssrc uint32) (*Connection, error) {
return DialConnectionCustom(ctx, &Dialer, addr, ssrc)
return DialConnectionCustom(ctx, &defaultDialer, addr, ssrc)
}
// DialConnectionCustom dials the UDP connection with a custom dialer.

View File

@ -2,7 +2,6 @@ package udp
import (
"context"
"net"
"sync"
"time"
@ -21,7 +20,7 @@ var ErrDialWhileUnpaused = errors.New("dial is called while manager is not pause
// Manager manages a UDP connection. It allows reconnecting. A Manager instance
// is thread-safe, meaning it can be used concurrently.
type Manager struct {
dialer *net.Dialer
dialer DialFunc
stopMu sync.Mutex
stopConn chan struct{}
@ -35,10 +34,16 @@ type Manager struct {
timeIncr uint32
}
// NewManager creates a new UDP connection manager with the defalt dialer.
// NewManager creates a new UDP connection manager with the default dial
// function DialConnection.
func NewManager() *Manager {
return NewManagerWithDialer(DialConnection)
}
// NewManagerWithDialer creates a UDP manager with an existing dial function.
func NewManagerWithDialer(dialer DialFunc) *Manager {
return &Manager{
dialer: &Dialer,
dialer: dialer,
stopConn: make(chan struct{}),
connLock: make(chan struct{}, 1),
}
@ -47,7 +52,7 @@ func NewManager() *Manager {
// SetDialer sets the manager's dialer. Calling this function while the Manager
// is working will cause a panic. Only call this method directly after
// construction.
func (m *Manager) SetDialer(d *net.Dialer) {
func (m *Manager) SetDialer(d DialFunc) {
select {
case m.connLock <- struct{}{}:
m.dialer = d
@ -136,7 +141,7 @@ func (m *Manager) Dial(ctx context.Context, addr string, ssrc uint32) (*Connecti
m.stopDial = cancel
m.stopMu.Unlock()
conn, err := DialConnectionCustom(ctx, m.dialer, addr, ssrc)
conn, err := m.dialer(ctx, addr, ssrc)
if err != nil {
// Unlock if we failed.
<-m.connLock