package sipguardian import ( "net" "testing" "time" "go.uber.org/zap" ) // TestDNSWhitelistConfig tests configuration defaults and validation func TestDNSWhitelistConfig(t *testing.T) { logger := zap.NewNop() tests := []struct { name string config DNSWhitelistConfig expectedRefresh time.Duration expectedTimeout time.Duration }{ { name: "empty config uses defaults", config: DNSWhitelistConfig{}, expectedRefresh: 5 * time.Minute, expectedTimeout: 10 * time.Second, }, { name: "custom refresh interval", config: DNSWhitelistConfig{ RefreshInterval: 10 * time.Minute, }, expectedRefresh: 10 * time.Minute, expectedTimeout: 10 * time.Second, }, { name: "custom timeout", config: DNSWhitelistConfig{ ResolveTimeout: 30 * time.Second, }, expectedRefresh: 5 * time.Minute, expectedTimeout: 30 * time.Second, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { wl := NewDNSWhitelist(tt.config, logger) if wl.config.RefreshInterval != tt.expectedRefresh { t.Errorf("RefreshInterval = %v, want %v", wl.config.RefreshInterval, tt.expectedRefresh) } if wl.config.ResolveTimeout != tt.expectedTimeout { t.Errorf("ResolveTimeout = %v, want %v", wl.config.ResolveTimeout, tt.expectedTimeout) } if !wl.config.AllowStale { t.Error("AllowStale should default to true") } }) } } // TestDNSWhitelistContains tests the IP lookup functionality func TestDNSWhitelistContains(t *testing.T) { logger := zap.NewNop() wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger) // Manually add some entries for testing now := time.Now() wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{ IP: "192.168.1.100", Source: "test.example.com", SourceType: "hostname", ResolvedAt: now, ExpiresAt: now.Add(10 * time.Minute), } wl.resolvedIPs["10.0.0.50"] = &ResolvedEntry{ IP: "10.0.0.50", Source: "_sip._udp.provider.com", SourceType: "srv", ResolvedAt: now, ExpiresAt: now.Add(10 * time.Minute), } tests := []struct { ip string expected bool }{ {"192.168.1.100", true}, {"10.0.0.50", true}, {"8.8.8.8", false}, {"192.168.1.101", false}, {"", false}, } for _, tt := range tests { t.Run(tt.ip, func(t *testing.T) { if got := wl.Contains(tt.ip); got != tt.expected { t.Errorf("Contains(%s) = %v, want %v", tt.ip, got, tt.expected) } }) } } // TestDNSWhitelistExpiredEntries tests handling of expired entries func TestDNSWhitelistExpiredEntries(t *testing.T) { logger := zap.NewNop() t.Run("expired entry rejected when AllowStale=false", func(t *testing.T) { wl := NewDNSWhitelist(DNSWhitelistConfig{ AllowStale: false, }, logger) // Override the default wl.config.AllowStale = false now := time.Now() wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{ IP: "192.168.1.100", Source: "test.example.com", SourceType: "hostname", ResolvedAt: now.Add(-1 * time.Hour), ExpiresAt: now.Add(-30 * time.Minute), // Already expired } if wl.Contains("192.168.1.100") { t.Error("Expired entry should not match when AllowStale=false") } }) t.Run("expired entry allowed when AllowStale=true", func(t *testing.T) { wl := NewDNSWhitelist(DNSWhitelistConfig{ AllowStale: true, }, logger) now := time.Now() wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{ IP: "192.168.1.100", Source: "test.example.com", SourceType: "hostname", ResolvedAt: now.Add(-1 * time.Hour), ExpiresAt: now.Add(-30 * time.Minute), // Already expired } if !wl.Contains("192.168.1.100") { t.Error("Expired entry should match when AllowStale=true") } }) } // TestDNSWhitelistGetSource tests source info retrieval func TestDNSWhitelistGetSource(t *testing.T) { logger := zap.NewNop() wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger) now := time.Now() wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{ IP: "192.168.1.100", Source: "pbx.example.com", SourceType: "hostname", ResolvedAt: now, ExpiresAt: now.Add(10 * time.Minute), } t.Run("existing IP returns source", func(t *testing.T) { source := wl.GetSource("192.168.1.100") if source == nil { t.Fatal("GetSource returned nil for existing IP") } if source.Source != "pbx.example.com" { t.Errorf("Source = %s, want pbx.example.com", source.Source) } if source.SourceType != "hostname" { t.Errorf("SourceType = %s, want hostname", source.SourceType) } }) t.Run("non-existent IP returns nil", func(t *testing.T) { source := wl.GetSource("8.8.8.8") if source != nil { t.Error("GetSource should return nil for non-existent IP") } }) } // TestDNSWhitelistGetResolvedIPs tests listing all entries func TestDNSWhitelistGetResolvedIPs(t *testing.T) { logger := zap.NewNop() wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger) // Empty whitelist entries := wl.GetResolvedIPs() if len(entries) != 0 { t.Errorf("Empty whitelist should return 0 entries, got %d", len(entries)) } // Add entries now := time.Now() wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{ IP: "192.168.1.100", Source: "host1.example.com", SourceType: "hostname", ResolvedAt: now, ExpiresAt: now.Add(10 * time.Minute), } wl.resolvedIPs["10.0.0.50"] = &ResolvedEntry{ IP: "10.0.0.50", Source: "_sip._udp.provider.com", SourceType: "srv", ResolvedAt: now, ExpiresAt: now.Add(10 * time.Minute), } entries = wl.GetResolvedIPs() if len(entries) != 2 { t.Errorf("Expected 2 entries, got %d", len(entries)) } } // TestDNSWhitelistStats tests statistics reporting func TestDNSWhitelistStats(t *testing.T) { logger := zap.NewNop() wl := NewDNSWhitelist(DNSWhitelistConfig{ Hostnames: []string{"host1.example.com", "host2.example.com"}, SRVRecords: []string{"_sip._udp.provider.com"}, RefreshInterval: 10 * time.Minute, }, logger) now := time.Now() wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{ IP: "192.168.1.100", Source: "host1.example.com", SourceType: "hostname", ResolvedAt: now, ExpiresAt: now.Add(10 * time.Minute), } wl.resolvedIPs["192.168.1.101"] = &ResolvedEntry{ IP: "192.168.1.101", Source: "host2.example.com", SourceType: "hostname", ResolvedAt: now, ExpiresAt: now.Add(10 * time.Minute), } wl.resolvedIPs["10.0.0.50"] = &ResolvedEntry{ IP: "10.0.0.50", Source: "_sip._udp.provider.com", SourceType: "srv", ResolvedAt: now, ExpiresAt: now.Add(10 * time.Minute), } stats := wl.Stats() if stats["total_ips"] != 3 { t.Errorf("total_ips = %v, want 3", stats["total_ips"]) } if stats["hostname_ips"] != 2 { t.Errorf("hostname_ips = %v, want 2", stats["hostname_ips"]) } if stats["srv_ips"] != 1 { t.Errorf("srv_ips = %v, want 1", stats["srv_ips"]) } if stats["configured_hosts"] != 2 { t.Errorf("configured_hosts = %v, want 2", stats["configured_hosts"]) } if stats["configured_srv"] != 1 { t.Errorf("configured_srv = %v, want 1", stats["configured_srv"]) } } // TestDNSWhitelistResolveHostname tests hostname resolution with direct IPs func TestDNSWhitelistResolveHostnameDirectIP(t *testing.T) { logger := zap.NewNop() wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger) // Test that direct IP addresses are handled correctly tests := []struct { input string expected string }{ {"192.168.1.100", "192.168.1.100"}, {"10.0.0.1", "10.0.0.1"}, {"::1", "::1"}, {"2001:db8::1", "2001:db8::1"}, } for _, tt := range tests { t.Run(tt.input, func(t *testing.T) { ips, err := wl.resolveHostname(nil, tt.input) if err != nil { t.Fatalf("resolveHostname(%s) error: %v", tt.input, err) } if len(ips) != 1 || ips[0] != tt.expected { t.Errorf("resolveHostname(%s) = %v, want [%s]", tt.input, ips, tt.expected) } }) } } // TestDNSWhitelistRealDNS tests actual DNS resolution (integration test) // This test requires network connectivity func TestDNSWhitelistRealDNS(t *testing.T) { if testing.Short() { t.Skip("Skipping DNS integration test in short mode") } logger := zap.NewNop() wl := NewDNSWhitelist(DNSWhitelistConfig{ Hostnames: []string{"localhost"}, RefreshInterval: 1 * time.Minute, ResolveTimeout: 5 * time.Second, }, logger) // Start and let it do initial resolution if err := wl.Start(); err != nil { t.Fatalf("Failed to start DNS whitelist: %v", err) } defer wl.Stop() // Give it a moment to resolve time.Sleep(100 * time.Millisecond) // localhost should resolve to 127.0.0.1 or ::1 if !wl.Contains("127.0.0.1") && !wl.Contains("::1") { entries := wl.GetResolvedIPs() t.Errorf("Expected localhost to resolve, got entries: %+v", entries) } } // TestDNSWhitelistStartStop tests the lifecycle management func TestDNSWhitelistStartStop(t *testing.T) { logger := zap.NewNop() wl := NewDNSWhitelist(DNSWhitelistConfig{ Hostnames: []string{"127.0.0.1"}, // Use IP to avoid DNS RefreshInterval: 100 * time.Millisecond, }, logger) // Start if err := wl.Start(); err != nil { t.Fatalf("Start() error: %v", err) } // Should have the IP immediately if !wl.Contains("127.0.0.1") { t.Error("Should contain 127.0.0.1 after start") } // Stop should complete without hanging done := make(chan struct{}) go func() { wl.Stop() close(done) }() select { case <-done: // Good, stopped cleanly case <-time.After(2 * time.Second): t.Error("Stop() took too long") } } // TestDNSWhitelistForceRefresh tests the manual refresh functionality func TestDNSWhitelistForceRefresh(t *testing.T) { logger := zap.NewNop() wl := NewDNSWhitelist(DNSWhitelistConfig{ Hostnames: []string{"127.0.0.1"}, RefreshInterval: 1 * time.Hour, // Long interval }, logger) // Don't start the refresh loop, just manually refresh wl.ForceRefresh() if !wl.Contains("127.0.0.1") { t.Error("ForceRefresh should resolve IPs") } } // TestDNSWhitelistConcurrency tests concurrent access to the whitelist func TestDNSWhitelistConcurrency(t *testing.T) { logger := zap.NewNop() wl := NewDNSWhitelist(DNSWhitelistConfig{ Hostnames: []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"}, RefreshInterval: 50 * time.Millisecond, }, logger) if err := wl.Start(); err != nil { t.Fatalf("Start() error: %v", err) } defer wl.Stop() // Run concurrent reads while refresh is happening done := make(chan struct{}) go func() { for i := 0; i < 1000; i++ { wl.Contains("127.0.0.1") wl.GetResolvedIPs() wl.Stats() } close(done) }() select { case <-done: // Good case <-time.After(5 * time.Second): t.Error("Concurrent operations took too long") } } // TestSIPGuardianDNSWhitelistIntegration tests integration with SIPGuardian func TestSIPGuardianDNSWhitelistIntegration(t *testing.T) { // Create a SIPGuardian with DNS whitelist config g := &SIPGuardian{ WhitelistHosts: []string{"127.0.0.1"}, WhitelistSRV: []string{}, } // Initialize maps g.bannedIPs = make(map[string]*BanEntry) g.failureCounts = make(map[string]*failureTracker) g.logger = zap.NewNop() // Manually create DNS whitelist (normally done in Provision) g.dnsWhitelist = NewDNSWhitelist(DNSWhitelistConfig{ Hostnames: g.WhitelistHosts, }, g.logger) g.dnsWhitelist.ForceRefresh() // Test IsWhitelisted if !g.IsWhitelisted("127.0.0.1") { t.Error("127.0.0.1 should be whitelisted via DNS") } if g.IsWhitelisted("8.8.8.8") { t.Error("8.8.8.8 should NOT be whitelisted") } } // TestSIPGuardianMixedWhitelist tests CIDR + DNS whitelist together func TestSIPGuardianMixedWhitelist(t *testing.T) { g := &SIPGuardian{ WhitelistCIDR: []string{"10.0.0.0/8"}, WhitelistHosts: []string{"127.0.0.1"}, } // Initialize g.bannedIPs = make(map[string]*BanEntry) g.failureCounts = make(map[string]*failureTracker) g.logger = zap.NewNop() // Parse CIDR whitelist (normally done in Provision) for _, cidr := range g.WhitelistCIDR { _, network, err := net.ParseCIDR(cidr) if err == nil { g.whitelistNets = append(g.whitelistNets, network) } } // Set up DNS whitelist g.dnsWhitelist = NewDNSWhitelist(DNSWhitelistConfig{ Hostnames: g.WhitelistHosts, }, g.logger) g.dnsWhitelist.ForceRefresh() // Test CIDR whitelist if !g.IsWhitelisted("10.1.2.3") { t.Error("10.1.2.3 should be whitelisted via CIDR") } // Test DNS whitelist if !g.IsWhitelisted("127.0.0.1") { t.Error("127.0.0.1 should be whitelisted via DNS") } // Test non-whitelisted if g.IsWhitelisted("8.8.8.8") { t.Error("8.8.8.8 should NOT be whitelisted") } } // BenchmarkDNSWhitelistContains benchmarks lookup performance func BenchmarkDNSWhitelistContains(b *testing.B) { logger := zap.NewNop() wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger) // Add 1000 entries now := time.Now() for i := 0; i < 1000; i++ { ip := "192.168." + string(rune('0'+i/256)) + "." + string(rune('0'+i%256)) wl.resolvedIPs[ip] = &ResolvedEntry{ IP: ip, Source: "test.example.com", ExpiresAt: now.Add(10 * time.Minute), } } b.ResetTimer() for i := 0; i < b.N; i++ { wl.Contains("192.168.0.100") } }