caddy-sip-guardian/dns_whitelist.go
Ryan Malloy 5cf34eb3c0 Add DNS-aware whitelisting feature
Support for whitelisting SIP trunks and providers by hostname or SRV
record with automatic IP resolution and periodic refresh.

Features:
- Hostname resolution via A/AAAA records
- SRV record resolution (e.g., _sip._udp.provider.com)
- Configurable refresh interval (default 5m)
- Stale entry handling when DNS fails
- Admin API endpoints for DNS whitelist management
- Caddyfile directives: whitelist_hosts, whitelist_srv, dns_refresh

This allows whitelisting by provider name rather than tracking
constantly-changing IP addresses.
2025-12-08 00:46:43 -07:00

383 lines
9.5 KiB
Go

// Package sipguardian provides DNS-aware whitelist functionality for SIP Guardian.
// This allows whitelisting by hostname, A/AAAA records, and SRV records.
package sipguardian
import (
"context"
"net"
"strings"
"sync"
"time"
"go.uber.org/zap"
)
// DNSWhitelistConfig holds configuration for DNS-based whitelisting.
// Zero-valued durations are replaced with defaults by NewDNSWhitelist.
type DNSWhitelistConfig struct {
	// Hostnames to resolve and whitelist (A/AAAA records). Entries that
	// already parse as IP addresses are used directly without a lookup.
	Hostnames []string `json:"hostnames,omitempty"`

	// SRV records to resolve (e.g., "_sip._udp.provider.com").
	// Will resolve SRV -> hostnames -> IPs. A value without the
	// "_service._proto." prefix is treated as a plain domain and tried as
	// _sip._udp, then _sip._tcp, and finally as a direct A/AAAA name.
	SRVRecords []string `json:"srv_records,omitempty"`

	// RefreshInterval for DNS lookups (default: 5m). Freshly resolved
	// entries are given an expiry of twice this interval.
	RefreshInterval time.Duration `json:"refresh_interval,omitempty"`

	// AllowStale allows using cached IPs if DNS refresh fails (default: true).
	// NOTE(review): NewDNSWhitelist sets this to true whenever it is false,
	// so an explicit false currently has no effect.
	AllowStale bool `json:"allow_stale,omitempty"`

	// ResolveTimeout is the DNS lookup time budget (default: 10s); refreshAll
	// derives the deadlines of its lookups from it.
	ResolveTimeout time.Duration `json:"resolve_timeout,omitempty"`
}
// DNSWhitelist manages DNS-based IP whitelisting with automatic refresh.
// Construct it with NewDNSWhitelist, call Start to perform the initial
// resolution and launch the periodic refresh goroutine, and Stop to end it.
type DNSWhitelist struct {
	config DNSWhitelistConfig
	logger *zap.Logger

	// resolvedIPs maps an IP string to the entry describing where it came
	// from. refreshAll swaps in a whole new map on every refresh; every
	// access in this file holds mu (read or write lock as appropriate).
	resolvedIPs map[string]*ResolvedEntry
	mu          sync.RWMutex

	// For graceful shutdown: stopCh is closed by Stop to end refreshLoop,
	// and wg lets Stop wait for that goroutine to exit.
	stopCh chan struct{}
	wg     sync.WaitGroup
}
// ResolvedEntry tracks where a whitelisted IP came from and when it was
// resolved.
type ResolvedEntry struct {
	IP         string    `json:"ip"`
	Source     string    `json:"source"`      // hostname or SRV record
	SourceType string    `json:"source_type"` // "hostname", "srv"
	ResolvedAt time.Time `json:"resolved_at"`

	// ExpiresAt marks when the entry becomes stale; refreshAll sets it to
	// the resolution time plus twice the refresh interval.
	ExpiresAt time.Time `json:"expires_at"`

	// TTL is intended to hold the DNS TTL in seconds.
	// NOTE(review): nothing in this file populates TTL, so it is always 0 here.
	TTL int `json:"ttl"`
}
// NewDNSWhitelist creates a DNS whitelist manager from config, filling in
// defaults for unset or invalid values:
//
//   - RefreshInterval: 5m when zero or negative (a non-positive interval
//     would make time.NewTicker panic inside the refresh goroutine)
//   - ResolveTimeout: 10s when zero or negative
//   - AllowStale: always true (see note below)
//
// A nil logger is replaced with a no-op logger so logging calls are safe.
//
// NOTE: AllowStale is a plain bool, so an explicit false cannot be told
// apart from "unset". It is therefore forced to true, and previously
// resolved IPs are always kept when DNS fails. Making it configurable would
// need a tri-state field (for example *bool).
//
// The returned whitelist holds no IPs until a refresh runs (Start or
// ForceRefresh).
func NewDNSWhitelist(config DNSWhitelistConfig, logger *zap.Logger) *DNSWhitelist {
	if config.RefreshInterval <= 0 {
		config.RefreshInterval = 5 * time.Minute
	}
	if config.ResolveTimeout <= 0 {
		config.ResolveTimeout = 10 * time.Second
	}
	// Equivalent to the historical `if !AllowStale { AllowStale = true }`.
	config.AllowStale = true

	if logger == nil {
		logger = zap.NewNop()
	}

	return &DNSWhitelist{
		config:      config,
		logger:      logger,
		resolvedIPs: make(map[string]*ResolvedEntry),
		stopCh:      make(chan struct{}),
	}
}
// Start synchronously resolves every configured source once, then launches
// the background goroutine that re-resolves them every RefreshInterval.
// It currently always returns nil.
func (d *DNSWhitelist) Start() error {
	d.refreshAll() // populate the whitelist before returning to the caller

	d.wg.Add(1)
	go d.refreshLoop()

	cfg := d.config
	d.logger.Info("DNS whitelist started",
		zap.Int("hostnames", len(cfg.Hostnames)),
		zap.Int("srv_records", len(cfg.SRVRecords)),
		zap.Duration("refresh_interval", cfg.RefreshInterval),
	)
	return nil
}
// Stop signals the background refresh loop to exit and waits for it to
// finish. It is safe to call more than once, including concurrently: only
// the first call closes stopCh, and later calls just wait.
func (d *DNSWhitelist) Stop() {
	// Closing an already-closed channel panics, so perform the
	// check-and-close under mu to make it atomic across callers.
	d.mu.Lock()
	select {
	case <-d.stopCh:
		// Already stopped by an earlier call.
	default:
		close(d.stopCh)
	}
	d.mu.Unlock()

	// Wait outside the lock: the refresh goroutine may need mu to finish.
	d.wg.Wait()
}
// Contains reports whether ip is in the DNS whitelist.
//
// The exact string is tried first. If it is not found and ip parses as an
// IP address, its canonical form (net.IP.String) is tried as well. That way
// equivalent spellings such as "2001:DB8::1" or "::ffff:192.0.2.1" match
// entries recorded as "2001:db8::1" or "192.0.2.1".
//
// An entry past its ExpiresAt is rejected only when AllowStale is false.
func (d *DNSWhitelist) Contains(ip string) bool {
	d.mu.RLock()
	defer d.mu.RUnlock()

	entry, exists := d.resolvedIPs[ip]
	if !exists {
		parsed := net.ParseIP(ip)
		if parsed == nil {
			return false
		}
		if entry, exists = d.resolvedIPs[parsed.String()]; !exists {
			return false
		}
	}

	// Expired entries are still honoured when stale use is allowed.
	if time.Now().After(entry.ExpiresAt) && !d.config.AllowStale {
		return false
	}
	return true
}
// GetSource returns a copy of the entry describing where a whitelisted IP
// came from, or nil if the IP is not in the whitelist.
//
// Like Contains, it tries the exact string first and then the canonical
// form of ip (when ip parses as an IP address).
func (d *DNSWhitelist) GetSource(ip string) *ResolvedEntry {
	d.mu.RLock()
	defer d.mu.RUnlock()

	entry, exists := d.resolvedIPs[ip]
	if !exists {
		if parsed := net.ParseIP(ip); parsed != nil {
			entry, exists = d.resolvedIPs[parsed.String()]
		}
	}
	if !exists {
		return nil
	}

	// Return a copy so callers cannot mutate the shared entry.
	entryCopy := *entry
	return &entryCopy
}
// GetResolvedIPs returns a snapshot of every entry currently in the
// whitelist. Entries are copied, so the result is safe to use after the
// lock is released; their order is unspecified.
func (d *DNSWhitelist) GetResolvedIPs() []ResolvedEntry {
	d.mu.RLock()
	defer d.mu.RUnlock()

	snapshot := make([]ResolvedEntry, len(d.resolvedIPs))
	i := 0
	for _, e := range d.resolvedIPs {
		snapshot[i] = *e
		i++
	}
	return snapshot
}
// refreshLoop re-resolves all DNS sources once per RefreshInterval until
// stopCh is closed. It is started by Start as its own goroutine and marks
// wg done when it exits.
func (d *DNSWhitelist) refreshLoop() {
	defer d.wg.Done()

	tick := time.NewTicker(d.config.RefreshInterval)
	defer tick.Stop()

	for running := true; running; {
		select {
		case <-tick.C:
			d.refreshAll()
		case <-d.stopCh:
			running = false
		}
	}
}
// refreshAll resolves all configured hostnames and SRV records, then
// atomically replaces the whitelist map with the result.
//
// Each configured source gets its own ResolveTimeout budget, so a single
// slow or unresponsive name cannot use up the time meant for the sources
// after it. Lookups are also cancelled as soon as Stop closes stopCh, so
// shutdown does not wait on a hung resolver.
//
// When a source fails to resolve and AllowStale is set, the entries it
// produced in the previous refresh are carried over unchanged.
func (d *DNSWhitelist) refreshAll() {
	// Parent context for every lookup in this refresh; cancelled on Stop.
	// The watcher goroutine exits when this function returns (deferred cancel).
	parent, cancel := context.WithCancel(context.Background())
	defer cancel()
	go func() {
		select {
		case <-d.stopCh:
			cancel()
		case <-parent.Done():
		}
	}()

	// lookupCtx returns a fresh per-source deadline derived from parent.
	lookupCtx := func() (context.Context, context.CancelFunc) {
		return context.WithTimeout(parent, d.config.ResolveTimeout)
	}

	newResolved := make(map[string]*ResolvedEntry)
	now := time.Now()
	defaultExpiry := now.Add(d.config.RefreshInterval * 2) // 2x refresh interval as safety margin

	// Resolve hostnames (A/AAAA records)
	for _, hostname := range d.config.Hostnames {
		ctx, cancelLookup := lookupCtx()
		ips, err := d.resolveHostname(ctx, hostname)
		cancelLookup()
		if err != nil {
			d.logger.Warn("Failed to resolve hostname",
				zap.String("hostname", hostname),
				zap.Error(err),
			)
			// Keep existing entries if AllowStale
			if d.config.AllowStale {
				d.copyExistingEntries(newResolved, hostname, "hostname")
			}
			continue
		}
		for _, ip := range ips {
			newResolved[ip] = &ResolvedEntry{
				IP:         ip,
				Source:     hostname,
				SourceType: "hostname",
				ResolvedAt: now,
				ExpiresAt:  defaultExpiry,
			}
		}
		d.logger.Debug("Resolved hostname",
			zap.String("hostname", hostname),
			zap.Strings("ips", ips),
		)
	}

	// Resolve SRV records (one budget covers the SRV query and its targets)
	for _, srv := range d.config.SRVRecords {
		ctx, cancelLookup := lookupCtx()
		ips, targets, err := d.resolveSRV(ctx, srv)
		cancelLookup()
		if err != nil {
			d.logger.Warn("Failed to resolve SRV record",
				zap.String("srv", srv),
				zap.Error(err),
			)
			// Keep existing entries if AllowStale
			if d.config.AllowStale {
				d.copyExistingEntries(newResolved, srv, "srv")
			}
			continue
		}
		for _, ip := range ips {
			newResolved[ip] = &ResolvedEntry{
				IP:         ip,
				Source:     srv,
				SourceType: "srv",
				ResolvedAt: now,
				ExpiresAt:  defaultExpiry,
			}
		}
		d.logger.Debug("Resolved SRV record",
			zap.String("srv", srv),
			zap.Strings("targets", targets),
			zap.Strings("ips", ips),
		)
	}

	// Update the map atomically
	d.mu.Lock()
	d.resolvedIPs = newResolved
	d.mu.Unlock()

	d.logger.Info("DNS whitelist refreshed",
		zap.Int("total_ips", len(newResolved)),
	)
}
// copyExistingEntries carries forward, into newMap, every entry in the
// current whitelist whose Source and SourceType match. refreshAll uses it
// to keep serving previously resolved IPs when a source fails to resolve.
//
// IPs that already have an entry in newMap are left untouched. That entry
// was produced earlier in the same refresh, and overwriting it would replace
// a fresh entry with an older, possibly expired one.
func (d *DNSWhitelist) copyExistingEntries(newMap map[string]*ResolvedEntry, source, sourceType string) {
	d.mu.RLock()
	defer d.mu.RUnlock()

	for ip, entry := range d.resolvedIPs {
		if entry.Source != source || entry.SourceType != sourceType {
			continue
		}
		if _, alreadySet := newMap[ip]; alreadySet {
			continue
		}
		newMap[ip] = entry
	}
}
// resolveHostname resolves hostname to its A/AAAA addresses using the
// default resolver, returning each address in net.IP's canonical string
// form.
//
// If hostname is already an IP literal, no lookup is performed. The literal
// is returned in canonical form (e.g. "2001:0DB8::0001" -> "2001:db8::1",
// "::ffff:192.0.2.1" -> "192.0.2.1") so it matches the canonical strings
// used as whitelist keys.
func (d *DNSWhitelist) resolveHostname(ctx context.Context, hostname string) ([]string, error) {
	if ip := net.ParseIP(hostname); ip != nil {
		return []string{ip.String()}, nil
	}

	addrs, err := net.DefaultResolver.LookupIPAddr(ctx, hostname)
	if err != nil {
		return nil, err
	}

	ips := make([]string, 0, len(addrs))
	for _, addr := range addrs {
		ips = append(ips, addr.IP.String())
	}
	return ips, nil
}
// resolveSRV resolves an SRV record and then each of its targets to IPs.
//
// srvRecord may be fully qualified ("_sip._udp.provider.com") or a plain
// domain ("provider.com"); a plain domain is looked up as _sip._udp. When a
// UDP SRV lookup fails it is retried over TCP, and when that also fails
// srvRecord itself is resolved as an A/AAAA name.
//
// It returns the de-duplicated IPs and the SRV target hostnames. It returns
// an error in two cases:
//   - no SRV (or fallback) answer could be obtained; this is the error from
//     the first SRV lookup.
//   - the SRV record listed targets but none of them resolved to an
//     address; this lets the caller keep stale entries instead of silently
//     dropping the provider.
func (d *DNSWhitelist) resolveSRV(ctx context.Context, srvRecord string) (ips []string, targets []string, err error) {
	// Parse "_service._proto.name"; anything else is a plain domain.
	service, proto, name := "sip", "udp", srvRecord
	parts := strings.Split(srvRecord, ".")
	if len(parts) >= 3 && strings.HasPrefix(parts[0], "_") && strings.HasPrefix(parts[1], "_") {
		service = strings.TrimPrefix(parts[0], "_")
		proto = strings.TrimPrefix(parts[1], "_")
		name = strings.Join(parts[2:], ".")
	}

	resolver := net.DefaultResolver
	_, records, srvErr := resolver.LookupSRV(ctx, service, proto, name)
	if srvErr != nil {
		if proto != "udp" {
			return nil, nil, srvErr
		}
		// UDP SRV failed; retry over TCP.
		_, records, err = resolver.LookupSRV(ctx, service, "tcp", name)
		if err != nil {
			// Neither SRV form answered: fall back to an A/AAAA lookup.
			directIPs, aErr := d.resolveHostname(ctx, srvRecord)
			if aErr != nil {
				return nil, nil, srvErr // report the original (UDP) SRV error
			}
			return directIPs, []string{srvRecord}, nil
		}
	}

	// Resolve each SRV target to IPs, de-duplicating across targets.
	seenIPs := make(map[string]bool)
	var lastTargetErr error
	for _, rec := range records {
		target := strings.TrimSuffix(rec.Target, ".")
		if target == "" {
			// RFC 2782: a target of "." means the service is decidedly not
			// available at this domain; there is nothing to resolve.
			continue
		}
		targets = append(targets, target)

		targetIPs, tErr := d.resolveHostname(ctx, target)
		if tErr != nil {
			d.logger.Warn("Failed to resolve SRV target",
				zap.String("srv", srvRecord),
				zap.String("target", target),
				zap.Error(tErr),
			)
			lastTargetErr = tErr
			continue
		}
		for _, ip := range targetIPs {
			if !seenIPs[ip] {
				seenIPs[ip] = true
				ips = append(ips, ip)
			}
		}
	}

	if len(ips) == 0 && lastTargetErr != nil {
		// Every target failed: treat the whole record as failed.
		return nil, nil, lastTargetErr
	}
	return ips, targets, nil
}
// ForceRefresh re-resolves every configured source right away,
// synchronously on the caller's goroutine, without waiting for the next tick.
func (d *DNSWhitelist) ForceRefresh() { d.refreshAll() }
// Stats summarizes the whitelist. It reports the total IP count, the count
// per source type ("hostname"/"srv"), how many entries are past their
// ExpiresAt ("stale_count"), and the relevant configuration values.
func (d *DNSWhitelist) Stats() map[string]interface{} {
	d.mu.RLock()
	defer d.mu.RUnlock()

	var (
		fromHostname int
		fromSRV      int
		expired      int
	)
	checkedAt := time.Now()
	for _, e := range d.resolvedIPs {
		if e.SourceType == "hostname" {
			fromHostname++
		} else if e.SourceType == "srv" {
			fromSRV++
		}
		if checkedAt.After(e.ExpiresAt) {
			expired++
		}
	}

	stats := make(map[string]interface{}, 8)
	stats["total_ips"] = len(d.resolvedIPs)
	stats["hostname_ips"] = fromHostname
	stats["srv_ips"] = fromSRV
	stats["stale_count"] = expired
	stats["configured_hosts"] = len(d.config.Hostnames)
	stats["configured_srv"] = len(d.config.SRVRecords)
	stats["refresh_interval"] = d.config.RefreshInterval.String()
	stats["allow_stale"] = d.config.AllowStale
	return stats
}