caddy-sip-guardian/dns_whitelist_test.go
Ryan Malloy 5cf34eb3c0 Add DNS-aware whitelisting feature
Support for whitelisting SIP trunks and providers by hostname or SRV
record with automatic IP resolution and periodic refresh.

Features:
- Hostname resolution via A/AAAA records
- SRV record resolution (e.g., _sip._udp.provider.com)
- Configurable refresh interval (default 5m)
- Stale entry handling when DNS fails
- Admin API endpoints for DNS whitelist management
- Caddyfile directives: whitelist_hosts, whitelist_srv, dns_refresh

This allows whitelisting by provider name rather than tracking
constantly-changing IP addresses.
2025-12-08 00:46:43 -07:00

501 lines
13 KiB
Go

package sipguardian
import (
"net"
"testing"
"time"
"go.uber.org/zap"
)
// TestDNSWhitelistConfig tests configuration defaults and validation
func TestDNSWhitelistConfig(t *testing.T) {
logger := zap.NewNop()
tests := []struct {
name string
config DNSWhitelistConfig
expectedRefresh time.Duration
expectedTimeout time.Duration
}{
{
name: "empty config uses defaults",
config: DNSWhitelistConfig{},
expectedRefresh: 5 * time.Minute,
expectedTimeout: 10 * time.Second,
},
{
name: "custom refresh interval",
config: DNSWhitelistConfig{
RefreshInterval: 10 * time.Minute,
},
expectedRefresh: 10 * time.Minute,
expectedTimeout: 10 * time.Second,
},
{
name: "custom timeout",
config: DNSWhitelistConfig{
ResolveTimeout: 30 * time.Second,
},
expectedRefresh: 5 * time.Minute,
expectedTimeout: 30 * time.Second,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wl := NewDNSWhitelist(tt.config, logger)
if wl.config.RefreshInterval != tt.expectedRefresh {
t.Errorf("RefreshInterval = %v, want %v", wl.config.RefreshInterval, tt.expectedRefresh)
}
if wl.config.ResolveTimeout != tt.expectedTimeout {
t.Errorf("ResolveTimeout = %v, want %v", wl.config.ResolveTimeout, tt.expectedTimeout)
}
if !wl.config.AllowStale {
t.Error("AllowStale should default to true")
}
})
}
}
// TestDNSWhitelistContains tests the IP lookup functionality
func TestDNSWhitelistContains(t *testing.T) {
logger := zap.NewNop()
wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger)
// Manually add some entries for testing
now := time.Now()
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
IP: "192.168.1.100",
Source: "test.example.com",
SourceType: "hostname",
ResolvedAt: now,
ExpiresAt: now.Add(10 * time.Minute),
}
wl.resolvedIPs["10.0.0.50"] = &ResolvedEntry{
IP: "10.0.0.50",
Source: "_sip._udp.provider.com",
SourceType: "srv",
ResolvedAt: now,
ExpiresAt: now.Add(10 * time.Minute),
}
tests := []struct {
ip string
expected bool
}{
{"192.168.1.100", true},
{"10.0.0.50", true},
{"8.8.8.8", false},
{"192.168.1.101", false},
{"", false},
}
for _, tt := range tests {
t.Run(tt.ip, func(t *testing.T) {
if got := wl.Contains(tt.ip); got != tt.expected {
t.Errorf("Contains(%s) = %v, want %v", tt.ip, got, tt.expected)
}
})
}
}
// TestDNSWhitelistExpiredEntries checks how Contains treats an entry whose
// ExpiresAt lies in the past: it must be rejected when AllowStale is false and
// still accepted when AllowStale is true.
func TestDNSWhitelistExpiredEntries(t *testing.T) {
	logger := zap.NewNop()

	cases := []struct {
		name       string
		allowStale bool
		wantMatch  bool
		failMsg    string
	}{
		{
			name:       "expired entry rejected when AllowStale=false",
			allowStale: false,
			wantMatch:  false,
			failMsg:    "Expired entry should not match when AllowStale=false",
		},
		{
			name:       "expired entry allowed when AllowStale=true",
			allowStale: true,
			wantMatch:  true,
			failMsg:    "Expired entry should match when AllowStale=true",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			wl := NewDNSWhitelist(DNSWhitelistConfig{AllowStale: tc.allowStale}, logger)
			if !tc.allowStale {
				// NewDNSWhitelist appears to enable AllowStale even when the
				// config leaves it false, so switch it off after construction.
				wl.config.AllowStale = false
			}

			now := time.Now()
			wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
				IP:         "192.168.1.100",
				Source:     "test.example.com",
				SourceType: "hostname",
				ResolvedAt: now.Add(-1 * time.Hour),
				ExpiresAt:  now.Add(-30 * time.Minute), // already expired
			}

			if wl.Contains("192.168.1.100") != tc.wantMatch {
				t.Error(tc.failMsg)
			}
		})
	}
}
// TestDNSWhitelistGetSource tests source info retrieval
func TestDNSWhitelistGetSource(t *testing.T) {
logger := zap.NewNop()
wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger)
now := time.Now()
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
IP: "192.168.1.100",
Source: "pbx.example.com",
SourceType: "hostname",
ResolvedAt: now,
ExpiresAt: now.Add(10 * time.Minute),
}
t.Run("existing IP returns source", func(t *testing.T) {
source := wl.GetSource("192.168.1.100")
if source == nil {
t.Fatal("GetSource returned nil for existing IP")
}
if source.Source != "pbx.example.com" {
t.Errorf("Source = %s, want pbx.example.com", source.Source)
}
if source.SourceType != "hostname" {
t.Errorf("SourceType = %s, want hostname", source.SourceType)
}
})
t.Run("non-existent IP returns nil", func(t *testing.T) {
source := wl.GetSource("8.8.8.8")
if source != nil {
t.Error("GetSource should return nil for non-existent IP")
}
})
}
// TestDNSWhitelistGetResolvedIPs tests listing all entries
func TestDNSWhitelistGetResolvedIPs(t *testing.T) {
logger := zap.NewNop()
wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger)
// Empty whitelist
entries := wl.GetResolvedIPs()
if len(entries) != 0 {
t.Errorf("Empty whitelist should return 0 entries, got %d", len(entries))
}
// Add entries
now := time.Now()
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
IP: "192.168.1.100",
Source: "host1.example.com",
SourceType: "hostname",
ResolvedAt: now,
ExpiresAt: now.Add(10 * time.Minute),
}
wl.resolvedIPs["10.0.0.50"] = &ResolvedEntry{
IP: "10.0.0.50",
Source: "_sip._udp.provider.com",
SourceType: "srv",
ResolvedAt: now,
ExpiresAt: now.Add(10 * time.Minute),
}
entries = wl.GetResolvedIPs()
if len(entries) != 2 {
t.Errorf("Expected 2 entries, got %d", len(entries))
}
}
// TestDNSWhitelistStats tests statistics reporting
func TestDNSWhitelistStats(t *testing.T) {
logger := zap.NewNop()
wl := NewDNSWhitelist(DNSWhitelistConfig{
Hostnames: []string{"host1.example.com", "host2.example.com"},
SRVRecords: []string{"_sip._udp.provider.com"},
RefreshInterval: 10 * time.Minute,
}, logger)
now := time.Now()
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
IP: "192.168.1.100",
Source: "host1.example.com",
SourceType: "hostname",
ResolvedAt: now,
ExpiresAt: now.Add(10 * time.Minute),
}
wl.resolvedIPs["192.168.1.101"] = &ResolvedEntry{
IP: "192.168.1.101",
Source: "host2.example.com",
SourceType: "hostname",
ResolvedAt: now,
ExpiresAt: now.Add(10 * time.Minute),
}
wl.resolvedIPs["10.0.0.50"] = &ResolvedEntry{
IP: "10.0.0.50",
Source: "_sip._udp.provider.com",
SourceType: "srv",
ResolvedAt: now,
ExpiresAt: now.Add(10 * time.Minute),
}
stats := wl.Stats()
if stats["total_ips"] != 3 {
t.Errorf("total_ips = %v, want 3", stats["total_ips"])
}
if stats["hostname_ips"] != 2 {
t.Errorf("hostname_ips = %v, want 2", stats["hostname_ips"])
}
if stats["srv_ips"] != 1 {
t.Errorf("srv_ips = %v, want 1", stats["srv_ips"])
}
if stats["configured_hosts"] != 2 {
t.Errorf("configured_hosts = %v, want 2", stats["configured_hosts"])
}
if stats["configured_srv"] != 1 {
t.Errorf("configured_srv = %v, want 1", stats["configured_srv"])
}
}
// TestDNSWhitelistResolveHostnameDirectIP checks that resolveHostname returns
// an IPv4 or IPv6 literal unchanged as a single-element result, so IP
// addresses listed as "hostnames" are usable as-is.
func TestDNSWhitelistResolveHostnameDirectIP(t *testing.T) {
	logger := zap.NewNop()
	wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger)

	tests := []struct {
		input    string
		expected string
	}{
		{"192.168.1.100", "192.168.1.100"},
		{"10.0.0.1", "10.0.0.1"},
		{"::1", "::1"},
		{"2001:db8::1", "2001:db8::1"},
	}
	for _, tt := range tests {
		t.Run(tt.input, func(t *testing.T) {
			// NOTE(review): nil is passed as the first argument. If that
			// parameter is a context.Context, context.Background() would be
			// preferable (staticcheck SA1012) — confirm the signature.
			ips, err := wl.resolveHostname(nil, tt.input)
			if err != nil {
				t.Fatalf("resolveHostname(%s) error: %v", tt.input, err)
			}
			if len(ips) != 1 || ips[0] != tt.expected {
				t.Errorf("resolveHostname(%s) = %v, want [%s]", tt.input, ips, tt.expected)
			}
		})
	}
}
// TestDNSWhitelistRealDNS is an integration test: after Start, "localhost"
// must be resolved through the system resolver to 127.0.0.1 or ::1.
// It depends on the host's resolver setup and is skipped with -short.
func TestDNSWhitelistRealDNS(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping DNS integration test in short mode")
	}

	const resolveTimeout = 5 * time.Second
	logger := zap.NewNop()
	wl := NewDNSWhitelist(DNSWhitelistConfig{
		Hostnames:       []string{"localhost"},
		RefreshInterval: 1 * time.Minute,
		ResolveTimeout:  resolveTimeout,
	}, logger)

	// Start and let it do initial resolution.
	if err := wl.Start(); err != nil {
		t.Fatalf("Failed to start DNS whitelist: %v", err)
	}
	defer wl.Stop()

	// localhost should resolve to 127.0.0.1 or ::1. Poll with a deadline
	// instead of sleeping a fixed 100ms, so a slow resolver does not make the
	// test flaky and a fast one does not waste time.
	resolved := func() bool {
		return wl.Contains("127.0.0.1") || wl.Contains("::1")
	}
	deadline := time.Now().Add(resolveTimeout)
	for !resolved() && time.Now().Before(deadline) {
		time.Sleep(10 * time.Millisecond)
	}

	if !resolved() {
		entries := wl.GetResolvedIPs()
		t.Errorf("Expected localhost to resolve, got entries: %+v", entries)
	}
}
// TestDNSWhitelistStartStop tests the lifecycle management
func TestDNSWhitelistStartStop(t *testing.T) {
logger := zap.NewNop()
wl := NewDNSWhitelist(DNSWhitelistConfig{
Hostnames: []string{"127.0.0.1"}, // Use IP to avoid DNS
RefreshInterval: 100 * time.Millisecond,
}, logger)
// Start
if err := wl.Start(); err != nil {
t.Fatalf("Start() error: %v", err)
}
// Should have the IP immediately
if !wl.Contains("127.0.0.1") {
t.Error("Should contain 127.0.0.1 after start")
}
// Stop should complete without hanging
done := make(chan struct{})
go func() {
wl.Stop()
close(done)
}()
select {
case <-done:
// Good, stopped cleanly
case <-time.After(2 * time.Second):
t.Error("Stop() took too long")
}
}
// TestDNSWhitelistForceRefresh tests the manual refresh functionality
func TestDNSWhitelistForceRefresh(t *testing.T) {
logger := zap.NewNop()
wl := NewDNSWhitelist(DNSWhitelistConfig{
Hostnames: []string{"127.0.0.1"},
RefreshInterval: 1 * time.Hour, // Long interval
}, logger)
// Don't start the refresh loop, just manually refresh
wl.ForceRefresh()
if !wl.Contains("127.0.0.1") {
t.Error("ForceRefresh should resolve IPs")
}
}
// TestDNSWhitelistConcurrency exercises the read paths (Contains,
// GetResolvedIPs, Stats) from several goroutines. Meanwhile the background
// refresh loop and explicit ForceRefresh calls rewrite the resolved set. The
// test is mainly useful under `go test -race`.
func TestDNSWhitelistConcurrency(t *testing.T) {
	const refreshInterval = 50 * time.Millisecond

	logger := zap.NewNop()
	wl := NewDNSWhitelist(DNSWhitelistConfig{
		Hostnames:       []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"},
		RefreshInterval: refreshInterval,
	}, logger)
	if err := wl.Start(); err != nil {
		t.Fatalf("Start() error: %v", err)
	}
	defer wl.Stop()

	// Readers run for several refresh intervals. A fixed small iteration count
	// finishes long before the first refresh tick, so it never overlaps a
	// refresh and gives the race detector nothing to observe.
	const (
		readers     = 4
		runDuration = 5 * refreshInterval
	)
	stop := make(chan struct{})
	// Buffered so that no goroutine blocks when signalling completion.
	finished := make(chan struct{}, readers+1)

	for r := 0; r < readers; r++ {
		go func() {
			defer func() { finished <- struct{}{} }()
			for {
				select {
				case <-stop:
					return
				default:
				}
				wl.Contains("127.0.0.1")
				wl.GetResolvedIPs()
				wl.Stats()
			}
		}()
	}

	// Writer: force extra refreshes in between the background ticks.
	go func() {
		defer func() { finished <- struct{}{} }()
		ticker := time.NewTicker(refreshInterval / 5)
		defer ticker.Stop()
		for {
			select {
			case <-stop:
				return
			case <-ticker.C:
				wl.ForceRefresh()
			}
		}
	}()

	time.Sleep(runDuration)
	close(stop)

	timeout := time.After(5 * time.Second)
	for i := 0; i < readers+1; i++ {
		select {
		case <-finished:
			// Good
		case <-timeout:
			t.Fatal("Concurrent operations took too long")
		}
	}
}
// TestSIPGuardianDNSWhitelistIntegration checks that SIPGuardian.IsWhitelisted
// consults the DNS whitelist: an address listed in WhitelistHosts is accepted
// and an unrelated address is not.
func TestSIPGuardianDNSWhitelistIntegration(t *testing.T) {
	guardian := &SIPGuardian{
		WhitelistHosts: []string{"127.0.0.1"},
		WhitelistSRV:   []string{},
	}

	// Provision is not run here, so initialize the internal fields by hand.
	guardian.bannedIPs = make(map[string]*BanEntry)
	guardian.failureCounts = make(map[string]*failureTracker)
	guardian.logger = zap.NewNop()

	// Build and populate the DNS whitelist (normally done in Provision).
	dnsCfg := DNSWhitelistConfig{Hostnames: guardian.WhitelistHosts}
	guardian.dnsWhitelist = NewDNSWhitelist(dnsCfg, guardian.logger)
	guardian.dnsWhitelist.ForceRefresh()

	if !guardian.IsWhitelisted("127.0.0.1") {
		t.Error("127.0.0.1 should be whitelisted via DNS")
	}
	if guardian.IsWhitelisted("8.8.8.8") {
		t.Error("8.8.8.8 should NOT be whitelisted")
	}
}
// TestSIPGuardianMixedWhitelist checks that IsWhitelisted accepts an address
// matched by either the CIDR whitelist or the DNS whitelist, and rejects one
// matched by neither.
func TestSIPGuardianMixedWhitelist(t *testing.T) {
	g := &SIPGuardian{
		WhitelistCIDR:  []string{"10.0.0.0/8"},
		WhitelistHosts: []string{"127.0.0.1"},
	}

	// Provision is not run here, so initialize the internal fields by hand.
	g.bannedIPs = make(map[string]*BanEntry)
	g.failureCounts = make(map[string]*failureTracker)
	g.logger = zap.NewNop()

	// Parse the CIDR whitelist (normally done in Provision). An unparsable
	// CIDR is a broken fixture: fail immediately instead of silently dropping
	// the network and letting the CIDR assertion below fail misleadingly.
	for _, cidr := range g.WhitelistCIDR {
		_, network, err := net.ParseCIDR(cidr)
		if err != nil {
			t.Fatalf("ParseCIDR(%q) error: %v", cidr, err)
		}
		g.whitelistNets = append(g.whitelistNets, network)
	}

	// Set up the DNS whitelist (normally done in Provision).
	g.dnsWhitelist = NewDNSWhitelist(DNSWhitelistConfig{
		Hostnames: g.WhitelistHosts,
	}, g.logger)
	g.dnsWhitelist.ForceRefresh()

	// Test CIDR whitelist
	if !g.IsWhitelisted("10.1.2.3") {
		t.Error("10.1.2.3 should be whitelisted via CIDR")
	}
	// Test DNS whitelist
	if !g.IsWhitelisted("127.0.0.1") {
		t.Error("127.0.0.1 should be whitelisted via DNS")
	}
	// Test non-whitelisted
	if g.IsWhitelisted("8.8.8.8") {
		t.Error("8.8.8.8 should NOT be whitelisted")
	}
}
// BenchmarkDNSWhitelistContains measures a Contains hit against a whitelist
// holding 1000 resolved entries (192.168.0.0 through 192.168.3.231).
func BenchmarkDNSWhitelistContains(b *testing.B) {
	logger := zap.NewNop()
	wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger)

	// Add 1000 entries keyed by real dotted-quad strings. Building octets with
	// string(rune('0'+n)) yields a single Unicode character rather than the
	// decimal number, which would make the probe below never match.
	now := time.Now()
	for i := 0; i < 1000; i++ {
		ip := net.IPv4(192, 168, byte(i/256), byte(i%256)).String()
		wl.resolvedIPs[ip] = &ResolvedEntry{
			IP:        ip,
			Source:    "test.example.com",
			ExpiresAt: now.Add(10 * time.Minute),
		}
	}

	// Guard against silently benchmarking the miss path.
	const probe = "192.168.0.100"
	if !wl.Contains(probe) {
		b.Fatalf("Contains(%s) = false; benchmark would only measure misses", probe)
	}

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		wl.Contains(probe)
	}
}