Support for whitelisting SIP trunks and providers by hostname or SRV record, with automatic IP resolution and periodic refresh.

Features:
- Hostname resolution via A/AAAA records
- SRV record resolution (e.g., _sip._udp.provider.com)
- Configurable refresh interval (default 5m)
- Stale-entry handling when DNS resolution fails
- Admin API endpoints for DNS whitelist management
- Caddyfile directives: whitelist_hosts, whitelist_srv, dns_refresh

This allows whitelisting by provider name rather than by constantly changing IP addresses.
501 lines
13 KiB
Go
501 lines
13 KiB
Go
package sipguardian
|
|
|
|
import (
|
|
"net"
|
|
"testing"
|
|
"time"
|
|
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// TestDNSWhitelistConfig tests configuration defaults and validation
|
|
func TestDNSWhitelistConfig(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
|
|
tests := []struct {
|
|
name string
|
|
config DNSWhitelistConfig
|
|
expectedRefresh time.Duration
|
|
expectedTimeout time.Duration
|
|
}{
|
|
{
|
|
name: "empty config uses defaults",
|
|
config: DNSWhitelistConfig{},
|
|
expectedRefresh: 5 * time.Minute,
|
|
expectedTimeout: 10 * time.Second,
|
|
},
|
|
{
|
|
name: "custom refresh interval",
|
|
config: DNSWhitelistConfig{
|
|
RefreshInterval: 10 * time.Minute,
|
|
},
|
|
expectedRefresh: 10 * time.Minute,
|
|
expectedTimeout: 10 * time.Second,
|
|
},
|
|
{
|
|
name: "custom timeout",
|
|
config: DNSWhitelistConfig{
|
|
ResolveTimeout: 30 * time.Second,
|
|
},
|
|
expectedRefresh: 5 * time.Minute,
|
|
expectedTimeout: 30 * time.Second,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
wl := NewDNSWhitelist(tt.config, logger)
|
|
|
|
if wl.config.RefreshInterval != tt.expectedRefresh {
|
|
t.Errorf("RefreshInterval = %v, want %v", wl.config.RefreshInterval, tt.expectedRefresh)
|
|
}
|
|
if wl.config.ResolveTimeout != tt.expectedTimeout {
|
|
t.Errorf("ResolveTimeout = %v, want %v", wl.config.ResolveTimeout, tt.expectedTimeout)
|
|
}
|
|
if !wl.config.AllowStale {
|
|
t.Error("AllowStale should default to true")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestDNSWhitelistContains tests the IP lookup functionality
|
|
func TestDNSWhitelistContains(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger)
|
|
|
|
// Manually add some entries for testing
|
|
now := time.Now()
|
|
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
|
|
IP: "192.168.1.100",
|
|
Source: "test.example.com",
|
|
SourceType: "hostname",
|
|
ResolvedAt: now,
|
|
ExpiresAt: now.Add(10 * time.Minute),
|
|
}
|
|
wl.resolvedIPs["10.0.0.50"] = &ResolvedEntry{
|
|
IP: "10.0.0.50",
|
|
Source: "_sip._udp.provider.com",
|
|
SourceType: "srv",
|
|
ResolvedAt: now,
|
|
ExpiresAt: now.Add(10 * time.Minute),
|
|
}
|
|
|
|
tests := []struct {
|
|
ip string
|
|
expected bool
|
|
}{
|
|
{"192.168.1.100", true},
|
|
{"10.0.0.50", true},
|
|
{"8.8.8.8", false},
|
|
{"192.168.1.101", false},
|
|
{"", false},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.ip, func(t *testing.T) {
|
|
if got := wl.Contains(tt.ip); got != tt.expected {
|
|
t.Errorf("Contains(%s) = %v, want %v", tt.ip, got, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestDNSWhitelistExpiredEntries tests handling of expired entries
|
|
func TestDNSWhitelistExpiredEntries(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
|
|
t.Run("expired entry rejected when AllowStale=false", func(t *testing.T) {
|
|
wl := NewDNSWhitelist(DNSWhitelistConfig{
|
|
AllowStale: false,
|
|
}, logger)
|
|
// Override the default
|
|
wl.config.AllowStale = false
|
|
|
|
now := time.Now()
|
|
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
|
|
IP: "192.168.1.100",
|
|
Source: "test.example.com",
|
|
SourceType: "hostname",
|
|
ResolvedAt: now.Add(-1 * time.Hour),
|
|
ExpiresAt: now.Add(-30 * time.Minute), // Already expired
|
|
}
|
|
|
|
if wl.Contains("192.168.1.100") {
|
|
t.Error("Expired entry should not match when AllowStale=false")
|
|
}
|
|
})
|
|
|
|
t.Run("expired entry allowed when AllowStale=true", func(t *testing.T) {
|
|
wl := NewDNSWhitelist(DNSWhitelistConfig{
|
|
AllowStale: true,
|
|
}, logger)
|
|
|
|
now := time.Now()
|
|
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
|
|
IP: "192.168.1.100",
|
|
Source: "test.example.com",
|
|
SourceType: "hostname",
|
|
ResolvedAt: now.Add(-1 * time.Hour),
|
|
ExpiresAt: now.Add(-30 * time.Minute), // Already expired
|
|
}
|
|
|
|
if !wl.Contains("192.168.1.100") {
|
|
t.Error("Expired entry should match when AllowStale=true")
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestDNSWhitelistGetSource tests source info retrieval
|
|
func TestDNSWhitelistGetSource(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger)
|
|
|
|
now := time.Now()
|
|
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
|
|
IP: "192.168.1.100",
|
|
Source: "pbx.example.com",
|
|
SourceType: "hostname",
|
|
ResolvedAt: now,
|
|
ExpiresAt: now.Add(10 * time.Minute),
|
|
}
|
|
|
|
t.Run("existing IP returns source", func(t *testing.T) {
|
|
source := wl.GetSource("192.168.1.100")
|
|
if source == nil {
|
|
t.Fatal("GetSource returned nil for existing IP")
|
|
}
|
|
if source.Source != "pbx.example.com" {
|
|
t.Errorf("Source = %s, want pbx.example.com", source.Source)
|
|
}
|
|
if source.SourceType != "hostname" {
|
|
t.Errorf("SourceType = %s, want hostname", source.SourceType)
|
|
}
|
|
})
|
|
|
|
t.Run("non-existent IP returns nil", func(t *testing.T) {
|
|
source := wl.GetSource("8.8.8.8")
|
|
if source != nil {
|
|
t.Error("GetSource should return nil for non-existent IP")
|
|
}
|
|
})
|
|
}
|
|
|
|
// TestDNSWhitelistGetResolvedIPs tests listing all entries
|
|
func TestDNSWhitelistGetResolvedIPs(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger)
|
|
|
|
// Empty whitelist
|
|
entries := wl.GetResolvedIPs()
|
|
if len(entries) != 0 {
|
|
t.Errorf("Empty whitelist should return 0 entries, got %d", len(entries))
|
|
}
|
|
|
|
// Add entries
|
|
now := time.Now()
|
|
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
|
|
IP: "192.168.1.100",
|
|
Source: "host1.example.com",
|
|
SourceType: "hostname",
|
|
ResolvedAt: now,
|
|
ExpiresAt: now.Add(10 * time.Minute),
|
|
}
|
|
wl.resolvedIPs["10.0.0.50"] = &ResolvedEntry{
|
|
IP: "10.0.0.50",
|
|
Source: "_sip._udp.provider.com",
|
|
SourceType: "srv",
|
|
ResolvedAt: now,
|
|
ExpiresAt: now.Add(10 * time.Minute),
|
|
}
|
|
|
|
entries = wl.GetResolvedIPs()
|
|
if len(entries) != 2 {
|
|
t.Errorf("Expected 2 entries, got %d", len(entries))
|
|
}
|
|
}
|
|
|
|
// TestDNSWhitelistStats tests statistics reporting
|
|
func TestDNSWhitelistStats(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
wl := NewDNSWhitelist(DNSWhitelistConfig{
|
|
Hostnames: []string{"host1.example.com", "host2.example.com"},
|
|
SRVRecords: []string{"_sip._udp.provider.com"},
|
|
RefreshInterval: 10 * time.Minute,
|
|
}, logger)
|
|
|
|
now := time.Now()
|
|
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
|
|
IP: "192.168.1.100",
|
|
Source: "host1.example.com",
|
|
SourceType: "hostname",
|
|
ResolvedAt: now,
|
|
ExpiresAt: now.Add(10 * time.Minute),
|
|
}
|
|
wl.resolvedIPs["192.168.1.101"] = &ResolvedEntry{
|
|
IP: "192.168.1.101",
|
|
Source: "host2.example.com",
|
|
SourceType: "hostname",
|
|
ResolvedAt: now,
|
|
ExpiresAt: now.Add(10 * time.Minute),
|
|
}
|
|
wl.resolvedIPs["10.0.0.50"] = &ResolvedEntry{
|
|
IP: "10.0.0.50",
|
|
Source: "_sip._udp.provider.com",
|
|
SourceType: "srv",
|
|
ResolvedAt: now,
|
|
ExpiresAt: now.Add(10 * time.Minute),
|
|
}
|
|
|
|
stats := wl.Stats()
|
|
|
|
if stats["total_ips"] != 3 {
|
|
t.Errorf("total_ips = %v, want 3", stats["total_ips"])
|
|
}
|
|
if stats["hostname_ips"] != 2 {
|
|
t.Errorf("hostname_ips = %v, want 2", stats["hostname_ips"])
|
|
}
|
|
if stats["srv_ips"] != 1 {
|
|
t.Errorf("srv_ips = %v, want 1", stats["srv_ips"])
|
|
}
|
|
if stats["configured_hosts"] != 2 {
|
|
t.Errorf("configured_hosts = %v, want 2", stats["configured_hosts"])
|
|
}
|
|
if stats["configured_srv"] != 1 {
|
|
t.Errorf("configured_srv = %v, want 1", stats["configured_srv"])
|
|
}
|
|
}
|
|
|
|
// TestDNSWhitelistResolveHostname tests hostname resolution with direct IPs
|
|
func TestDNSWhitelistResolveHostnameDirectIP(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger)
|
|
|
|
// Test that direct IP addresses are handled correctly
|
|
tests := []struct {
|
|
input string
|
|
expected string
|
|
}{
|
|
{"192.168.1.100", "192.168.1.100"},
|
|
{"10.0.0.1", "10.0.0.1"},
|
|
{"::1", "::1"},
|
|
{"2001:db8::1", "2001:db8::1"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.input, func(t *testing.T) {
|
|
ips, err := wl.resolveHostname(nil, tt.input)
|
|
if err != nil {
|
|
t.Fatalf("resolveHostname(%s) error: %v", tt.input, err)
|
|
}
|
|
if len(ips) != 1 || ips[0] != tt.expected {
|
|
t.Errorf("resolveHostname(%s) = %v, want [%s]", tt.input, ips, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestDNSWhitelistRealDNS tests actual DNS resolution (integration test)
|
|
// This test requires network connectivity
|
|
func TestDNSWhitelistRealDNS(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("Skipping DNS integration test in short mode")
|
|
}
|
|
|
|
logger := zap.NewNop()
|
|
wl := NewDNSWhitelist(DNSWhitelistConfig{
|
|
Hostnames: []string{"localhost"},
|
|
RefreshInterval: 1 * time.Minute,
|
|
ResolveTimeout: 5 * time.Second,
|
|
}, logger)
|
|
|
|
// Start and let it do initial resolution
|
|
if err := wl.Start(); err != nil {
|
|
t.Fatalf("Failed to start DNS whitelist: %v", err)
|
|
}
|
|
defer wl.Stop()
|
|
|
|
// Give it a moment to resolve
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
// localhost should resolve to 127.0.0.1 or ::1
|
|
if !wl.Contains("127.0.0.1") && !wl.Contains("::1") {
|
|
entries := wl.GetResolvedIPs()
|
|
t.Errorf("Expected localhost to resolve, got entries: %+v", entries)
|
|
}
|
|
}
|
|
|
|
// TestDNSWhitelistStartStop tests the lifecycle management
|
|
func TestDNSWhitelistStartStop(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
wl := NewDNSWhitelist(DNSWhitelistConfig{
|
|
Hostnames: []string{"127.0.0.1"}, // Use IP to avoid DNS
|
|
RefreshInterval: 100 * time.Millisecond,
|
|
}, logger)
|
|
|
|
// Start
|
|
if err := wl.Start(); err != nil {
|
|
t.Fatalf("Start() error: %v", err)
|
|
}
|
|
|
|
// Should have the IP immediately
|
|
if !wl.Contains("127.0.0.1") {
|
|
t.Error("Should contain 127.0.0.1 after start")
|
|
}
|
|
|
|
// Stop should complete without hanging
|
|
done := make(chan struct{})
|
|
go func() {
|
|
wl.Stop()
|
|
close(done)
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
// Good, stopped cleanly
|
|
case <-time.After(2 * time.Second):
|
|
t.Error("Stop() took too long")
|
|
}
|
|
}
|
|
|
|
// TestDNSWhitelistForceRefresh tests the manual refresh functionality
|
|
func TestDNSWhitelistForceRefresh(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
wl := NewDNSWhitelist(DNSWhitelistConfig{
|
|
Hostnames: []string{"127.0.0.1"},
|
|
RefreshInterval: 1 * time.Hour, // Long interval
|
|
}, logger)
|
|
|
|
// Don't start the refresh loop, just manually refresh
|
|
wl.ForceRefresh()
|
|
|
|
if !wl.Contains("127.0.0.1") {
|
|
t.Error("ForceRefresh should resolve IPs")
|
|
}
|
|
}
|
|
|
|
// TestDNSWhitelistConcurrency tests concurrent access to the whitelist
|
|
func TestDNSWhitelistConcurrency(t *testing.T) {
|
|
logger := zap.NewNop()
|
|
wl := NewDNSWhitelist(DNSWhitelistConfig{
|
|
Hostnames: []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"},
|
|
RefreshInterval: 50 * time.Millisecond,
|
|
}, logger)
|
|
|
|
if err := wl.Start(); err != nil {
|
|
t.Fatalf("Start() error: %v", err)
|
|
}
|
|
defer wl.Stop()
|
|
|
|
// Run concurrent reads while refresh is happening
|
|
done := make(chan struct{})
|
|
go func() {
|
|
for i := 0; i < 1000; i++ {
|
|
wl.Contains("127.0.0.1")
|
|
wl.GetResolvedIPs()
|
|
wl.Stats()
|
|
}
|
|
close(done)
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
// Good
|
|
case <-time.After(5 * time.Second):
|
|
t.Error("Concurrent operations took too long")
|
|
}
|
|
}
|
|
|
|
// TestSIPGuardianDNSWhitelistIntegration tests integration with SIPGuardian
|
|
func TestSIPGuardianDNSWhitelistIntegration(t *testing.T) {
|
|
// Create a SIPGuardian with DNS whitelist config
|
|
g := &SIPGuardian{
|
|
WhitelistHosts: []string{"127.0.0.1"},
|
|
WhitelistSRV: []string{},
|
|
}
|
|
|
|
// Initialize maps
|
|
g.bannedIPs = make(map[string]*BanEntry)
|
|
g.failureCounts = make(map[string]*failureTracker)
|
|
g.logger = zap.NewNop()
|
|
|
|
// Manually create DNS whitelist (normally done in Provision)
|
|
g.dnsWhitelist = NewDNSWhitelist(DNSWhitelistConfig{
|
|
Hostnames: g.WhitelistHosts,
|
|
}, g.logger)
|
|
g.dnsWhitelist.ForceRefresh()
|
|
|
|
// Test IsWhitelisted
|
|
if !g.IsWhitelisted("127.0.0.1") {
|
|
t.Error("127.0.0.1 should be whitelisted via DNS")
|
|
}
|
|
|
|
if g.IsWhitelisted("8.8.8.8") {
|
|
t.Error("8.8.8.8 should NOT be whitelisted")
|
|
}
|
|
}
|
|
|
|
// TestSIPGuardianMixedWhitelist tests CIDR + DNS whitelist together
|
|
func TestSIPGuardianMixedWhitelist(t *testing.T) {
|
|
g := &SIPGuardian{
|
|
WhitelistCIDR: []string{"10.0.0.0/8"},
|
|
WhitelistHosts: []string{"127.0.0.1"},
|
|
}
|
|
|
|
// Initialize
|
|
g.bannedIPs = make(map[string]*BanEntry)
|
|
g.failureCounts = make(map[string]*failureTracker)
|
|
g.logger = zap.NewNop()
|
|
|
|
// Parse CIDR whitelist (normally done in Provision)
|
|
for _, cidr := range g.WhitelistCIDR {
|
|
_, network, err := net.ParseCIDR(cidr)
|
|
if err == nil {
|
|
g.whitelistNets = append(g.whitelistNets, network)
|
|
}
|
|
}
|
|
|
|
// Set up DNS whitelist
|
|
g.dnsWhitelist = NewDNSWhitelist(DNSWhitelistConfig{
|
|
Hostnames: g.WhitelistHosts,
|
|
}, g.logger)
|
|
g.dnsWhitelist.ForceRefresh()
|
|
|
|
// Test CIDR whitelist
|
|
if !g.IsWhitelisted("10.1.2.3") {
|
|
t.Error("10.1.2.3 should be whitelisted via CIDR")
|
|
}
|
|
|
|
// Test DNS whitelist
|
|
if !g.IsWhitelisted("127.0.0.1") {
|
|
t.Error("127.0.0.1 should be whitelisted via DNS")
|
|
}
|
|
|
|
// Test non-whitelisted
|
|
if g.IsWhitelisted("8.8.8.8") {
|
|
t.Error("8.8.8.8 should NOT be whitelisted")
|
|
}
|
|
}
|
|
|
|
// BenchmarkDNSWhitelistContains benchmarks lookup performance
|
|
func BenchmarkDNSWhitelistContains(b *testing.B) {
|
|
logger := zap.NewNop()
|
|
wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger)
|
|
|
|
// Add 1000 entries
|
|
now := time.Now()
|
|
for i := 0; i < 1000; i++ {
|
|
ip := "192.168." + string(rune('0'+i/256)) + "." + string(rune('0'+i%256))
|
|
wl.resolvedIPs[ip] = &ResolvedEntry{
|
|
IP: ip,
|
|
Source: "test.example.com",
|
|
ExpiresAt: now.Add(10 * time.Minute),
|
|
}
|
|
}
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
wl.Contains("192.168.0.100")
|
|
}
|
|
}
|