// Package sipguardian provides DNS-aware whitelist functionality for SIP Guardian. // This allows whitelisting by hostname, A/AAAA records, and SRV records. package sipguardian import ( "context" "net" "strings" "sync" "time" "go.uber.org/zap" ) // DNSWhitelistConfig holds configuration for DNS-based whitelisting type DNSWhitelistConfig struct { // Hostnames to resolve and whitelist (A/AAAA records) Hostnames []string `json:"hostnames,omitempty"` // SRV records to resolve (e.g., "_sip._udp.provider.com") // Will resolve SRV -> hostnames -> IPs SRVRecords []string `json:"srv_records,omitempty"` // RefreshInterval for DNS lookups (default: 5m) RefreshInterval time.Duration `json:"refresh_interval,omitempty"` // AllowStale allows using cached IPs if DNS refresh fails (default: true) AllowStale bool `json:"allow_stale,omitempty"` // ResolveTimeout for individual DNS queries (default: 10s) ResolveTimeout time.Duration `json:"resolve_timeout,omitempty"` } // DNSWhitelist manages DNS-based IP whitelisting with automatic refresh type DNSWhitelist struct { config DNSWhitelistConfig logger *zap.Logger // Resolved IPs with their source info resolvedIPs map[string]*ResolvedEntry mu sync.RWMutex // For graceful shutdown stopCh chan struct{} wg sync.WaitGroup } // ResolvedEntry tracks where an IP came from type ResolvedEntry struct { IP string `json:"ip"` Source string `json:"source"` // hostname or SRV record SourceType string `json:"source_type"` // "hostname", "srv" ResolvedAt time.Time `json:"resolved_at"` ExpiresAt time.Time `json:"expires_at"` TTL int `json:"ttl"` // DNS TTL in seconds } // NewDNSWhitelist creates a new DNS whitelist manager func NewDNSWhitelist(config DNSWhitelistConfig, logger *zap.Logger) *DNSWhitelist { // Set defaults if config.RefreshInterval == 0 { config.RefreshInterval = 5 * time.Minute } if config.ResolveTimeout == 0 { config.ResolveTimeout = 10 * time.Second } if !config.AllowStale { config.AllowStale = true // Default to allowing stale on DNS failure } return &DNSWhitelist{ config: config, logger: logger, resolvedIPs: make(map[string]*ResolvedEntry), stopCh: make(chan struct{}), } } // Start begins the DNS refresh loop func (d *DNSWhitelist) Start() error { // Do initial resolution d.refreshAll() // Start background refresh d.wg.Add(1) go d.refreshLoop() d.logger.Info("DNS whitelist started", zap.Int("hostnames", len(d.config.Hostnames)), zap.Int("srv_records", len(d.config.SRVRecords)), zap.Duration("refresh_interval", d.config.RefreshInterval), ) return nil } // Stop stops the DNS refresh loop func (d *DNSWhitelist) Stop() { close(d.stopCh) d.wg.Wait() } // Contains checks if an IP is in the DNS whitelist func (d *DNSWhitelist) Contains(ip string) bool { d.mu.RLock() defer d.mu.RUnlock() entry, exists := d.resolvedIPs[ip] if !exists { return false } // Check if entry is still valid if time.Now().After(entry.ExpiresAt) && !d.config.AllowStale { return false } return true } // GetSource returns the source info for a whitelisted IP func (d *DNSWhitelist) GetSource(ip string) *ResolvedEntry { d.mu.RLock() defer d.mu.RUnlock() if entry, exists := d.resolvedIPs[ip]; exists { // Return copy to avoid race conditions entryCopy := *entry return &entryCopy } return nil } // GetResolvedIPs returns all currently resolved IPs func (d *DNSWhitelist) GetResolvedIPs() []ResolvedEntry { d.mu.RLock() defer d.mu.RUnlock() entries := make([]ResolvedEntry, 0, len(d.resolvedIPs)) for _, entry := range d.resolvedIPs { entries = append(entries, *entry) } return entries } // refreshLoop periodically refreshes DNS entries func (d *DNSWhitelist) refreshLoop() { defer d.wg.Done() ticker := time.NewTicker(d.config.RefreshInterval) defer ticker.Stop() for { select { case <-d.stopCh: return case <-ticker.C: d.refreshAll() } } } // refreshAll resolves all configured hostnames and SRV records func (d *DNSWhitelist) refreshAll() { ctx, cancel := context.WithTimeout(context.Background(), d.config.ResolveTimeout*time.Duration(len(d.config.Hostnames)+len(d.config.SRVRecords)+1)) defer cancel() newResolved := make(map[string]*ResolvedEntry) now := time.Now() defaultExpiry := now.Add(d.config.RefreshInterval * 2) // 2x refresh interval as safety margin // Resolve hostnames (A/AAAA records) for _, hostname := range d.config.Hostnames { ips, err := d.resolveHostname(ctx, hostname) if err != nil { d.logger.Warn("Failed to resolve hostname", zap.String("hostname", hostname), zap.Error(err), ) // Keep existing entries if AllowStale if d.config.AllowStale { d.copyExistingEntries(newResolved, hostname, "hostname") } continue } for _, ip := range ips { newResolved[ip] = &ResolvedEntry{ IP: ip, Source: hostname, SourceType: "hostname", ResolvedAt: now, ExpiresAt: defaultExpiry, } } d.logger.Debug("Resolved hostname", zap.String("hostname", hostname), zap.Strings("ips", ips), ) } // Resolve SRV records for _, srv := range d.config.SRVRecords { ips, targets, err := d.resolveSRV(ctx, srv) if err != nil { d.logger.Warn("Failed to resolve SRV record", zap.String("srv", srv), zap.Error(err), ) // Keep existing entries if AllowStale if d.config.AllowStale { d.copyExistingEntries(newResolved, srv, "srv") } continue } for _, ip := range ips { newResolved[ip] = &ResolvedEntry{ IP: ip, Source: srv, SourceType: "srv", ResolvedAt: now, ExpiresAt: defaultExpiry, } } d.logger.Debug("Resolved SRV record", zap.String("srv", srv), zap.Strings("targets", targets), zap.Strings("ips", ips), ) } // Update the map atomically d.mu.Lock() d.resolvedIPs = newResolved d.mu.Unlock() d.logger.Info("DNS whitelist refreshed", zap.Int("total_ips", len(newResolved)), ) } // copyExistingEntries copies existing entries for a source to the new map func (d *DNSWhitelist) copyExistingEntries(newMap map[string]*ResolvedEntry, source, sourceType string) { d.mu.RLock() defer d.mu.RUnlock() for ip, entry := range d.resolvedIPs { if entry.Source == source && entry.SourceType == sourceType { newMap[ip] = entry } } } // resolveHostname resolves a hostname to IP addresses func (d *DNSWhitelist) resolveHostname(ctx context.Context, hostname string) ([]string, error) { // Handle case where hostname is already an IP if ip := net.ParseIP(hostname); ip != nil { return []string{hostname}, nil } resolver := net.DefaultResolver addrs, err := resolver.LookupIPAddr(ctx, hostname) if err != nil { return nil, err } ips := make([]string, 0, len(addrs)) for _, addr := range addrs { ips = append(ips, addr.IP.String()) } return ips, nil } // resolveSRV resolves an SRV record and all its targets func (d *DNSWhitelist) resolveSRV(ctx context.Context, srvRecord string) (ips []string, targets []string, err error) { // Parse SRV record format: _service._proto.name or just name // Example: _sip._udp.provider.com or _sip._tcp.provider.com var service, proto, name string parts := strings.Split(srvRecord, ".") if len(parts) >= 3 && strings.HasPrefix(parts[0], "_") && strings.HasPrefix(parts[1], "_") { service = strings.TrimPrefix(parts[0], "_") proto = strings.TrimPrefix(parts[1], "_") name = strings.Join(parts[2:], ".") } else { // Assume it's a plain domain, try common SIP SRV patterns // Try _sip._udp first, then _sip._tcp service = "sip" proto = "udp" name = srvRecord } resolver := net.DefaultResolver // Try to resolve SRV _, srvRecords, err := resolver.LookupSRV(ctx, service, proto, name) if err != nil { // If UDP fails, try TCP if proto == "udp" { _, srvRecords, err = resolver.LookupSRV(ctx, service, "tcp", name) if err != nil { // Fall back to A record lookup on the original name directIPs, aErr := d.resolveHostname(ctx, srvRecord) if aErr != nil { return nil, nil, err // Return original SRV error } return directIPs, []string{srvRecord}, nil } } else { return nil, nil, err } } // Resolve each SRV target to IPs seenIPs := make(map[string]bool) for _, srv := range srvRecords { target := strings.TrimSuffix(srv.Target, ".") targets = append(targets, target) targetIPs, err := d.resolveHostname(ctx, target) if err != nil { d.logger.Warn("Failed to resolve SRV target", zap.String("srv", srvRecord), zap.String("target", target), zap.Error(err), ) continue } for _, ip := range targetIPs { if !seenIPs[ip] { seenIPs[ip] = true ips = append(ips, ip) } } } return ips, targets, nil } // ForceRefresh triggers an immediate refresh of DNS entries func (d *DNSWhitelist) ForceRefresh() { d.refreshAll() } // Stats returns statistics about the DNS whitelist func (d *DNSWhitelist) Stats() map[string]interface{} { d.mu.RLock() defer d.mu.RUnlock() hostnameIPs := 0 srvIPs := 0 staleCount := 0 now := time.Now() for _, entry := range d.resolvedIPs { switch entry.SourceType { case "hostname": hostnameIPs++ case "srv": srvIPs++ } if now.After(entry.ExpiresAt) { staleCount++ } } return map[string]interface{}{ "total_ips": len(d.resolvedIPs), "hostname_ips": hostnameIPs, "srv_ips": srvIPs, "stale_count": staleCount, "configured_hosts": len(d.config.Hostnames), "configured_srv": len(d.config.SRVRecords), "refresh_interval": d.config.RefreshInterval.String(), "allow_stale": d.config.AllowStale, } }