From 5cf34eb3c036b7c382447075c0c37af050825cbf Mon Sep 17 00:00:00 2001 From: Ryan Malloy Date: Mon, 8 Dec 2025 00:46:43 -0700 Subject: [PATCH] Add DNS-aware whitelisting feature Support for whitelisting SIP trunks and providers by hostname or SRV record with automatic IP resolution and periodic refresh. Features: - Hostname resolution via A/AAAA records - SRV record resolution (e.g., _sip._udp.provider.com) - Configurable refresh interval (default 5m) - Stale entry handling when DNS fails - Admin API endpoints for DNS whitelist management - Caddyfile directives: whitelist_hosts, whitelist_srv, dns_refresh This allows whitelisting by provider name rather than tracking constantly-changing IP addresses. --- README.md | 45 ++- admin.go | 48 +++ dns_whitelist.go | 382 +++++++++++++++++++++++ dns_whitelist_test.go | 500 ++++++++++++++++++++++++++++++ l4handler.go | 9 +- l4handler_test.go | 596 ++++++++++++++++++++++++++++++++++++ ratelimit_test.go | 698 ++++++++++++++++++++++++++++++++++++++++++ sipguardian.go | 116 ++++++- 8 files changed, 2383 insertions(+), 11 deletions(-) create mode 100644 dns_whitelist.go create mode 100644 dns_whitelist_test.go create mode 100644 l4handler_test.go create mode 100644 ratelimit_test.go diff --git a/README.md b/README.md index 3b4fbc9..e255dce 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![Go Version](https://img.shields.io/badge/Go-1.25+-00ADD8?style=flat&logo=go)](https://go.dev/) [![Caddy](https://img.shields.io/badge/Caddy-2.10+-22b638?style=flat&logo=caddy)](https://caddyserver.com/) [![License](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE) -[![Tests](https://img.shields.io/badge/Tests-60%20passing-success)](https://git.supported.systems/rsp2k/caddy-sip-guardian) +[![Tests](https://img.shields.io/badge/Tests-196%20passing-success)](https://git.supported.systems/rsp2k/caddy-sip-guardian) > **A comprehensive Caddy module providing SIP-aware security at Layer 4.** > Protects your VoIP infrastructure with intelligent rate limiting, attack detection, message validation, and topology hiding. @@ -31,7 +31,8 @@ Traditional SIP security (like fail2ban) parses logs *after* attacks reach your - **Intelligent Rate Limiting** — Per-method token bucket rate limiting with burst support - **Automatic Banning** — Ban IPs that exceed failure thresholds - **Attack Detection** — Detect common SIP scanning tools (SIPVicious, friendly-scanner, etc.) -- **CIDR Whitelisting** — Whitelist trusted networks +- **CIDR Whitelisting** — Whitelist trusted networks by IP range +- **DNS-aware Whitelisting** — Whitelist SIP trunks by hostname or SRV record with auto-refresh - **GeoIP Blocking** — Block traffic by country using MaxMind databases ### 🔍 Extension Enumeration Detection @@ -298,6 +299,46 @@ enumeration { --- +### DNS-aware Whitelisting + +Whitelist SIP trunks and providers by hostname or SRV record. IPs are automatically resolved and refreshed: + +```caddyfile +sip_guardian { + # Static CIDR whitelist (always available) + whitelist 10.0.0.0/8 192.168.0.0/16 + + # DNS-aware whitelist - resolved to IPs automatically + whitelist_hosts pbx.example.com trunk.sipcarrier.net + whitelist_srv _sip._udp.provider.com _sip._tcp.carrier.net + dns_refresh 5m # How often to refresh DNS lookups (default: 5m) +} +``` + +**Why DNS-aware whitelisting?** + +| Static IP Whitelisting | DNS-aware Whitelisting | +|------------------------|------------------------| +| Breaks when provider changes IPs | Auto-updates when IPs change | +| Must manually track carrier IPs | Just use their SRV record | +| Fails silently on changes | Logs refresh events | + +**SRV Record Support:** + +SIP trunks commonly use SRV records for load balancing and failover. SIP Guardian resolves the full chain: +``` +_sip._udp.carrier.com → sip1.carrier.com, sip2.carrier.com → 203.0.113.10, 203.0.113.11 +``` + +**Admin API Endpoints:** + +| Method | Endpoint | Description | +|--------|----------|-------------| +| `GET` | `/api/sip-guardian/dns-whitelist` | List all resolved DNS entries | +| `POST` | `/api/sip-guardian/dns-whitelist/refresh` | Force immediate DNS refresh | + +--- + ### SIP Message Validation Enforces RFC 3261 compliance and blocks malformed/malicious packets: diff --git a/admin.go b/admin.go index f7dfb10..e7f0462 100644 --- a/admin.go +++ b/admin.go @@ -57,6 +57,10 @@ func (h *AdminHandler) ServeHTTP(w http.ResponseWriter, r *http.Request, next ca return h.handleBans(w, r) case strings.HasSuffix(path, "/stats"): return h.handleStats(w, r) + case strings.HasSuffix(path, "/dns-whitelist"): + return h.handleDNSWhitelist(w, r) + case strings.HasSuffix(path, "/dns-whitelist/refresh"): + return h.handleDNSWhitelistRefresh(w, r) case strings.Contains(path, "/unban/"): return h.handleUnban(w, r, path) case strings.Contains(path, "/ban/"): @@ -131,6 +135,50 @@ func (h *AdminHandler) handleUnban(w http.ResponseWriter, r *http.Request, path return nil } +// handleDNSWhitelist returns DNS whitelist entries and stats +func (h *AdminHandler) handleDNSWhitelist(w http.ResponseWriter, r *http.Request) error { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return nil + } + + entries := h.guardian.GetDNSWhitelistEntries() + stats := h.guardian.GetStats() + + response := map[string]interface{}{ + "entries": entries, + "count": len(entries), + } + + // Add DNS-specific stats if available + if dnsStats, ok := stats["dns_whitelist"]; ok { + response["stats"] = dnsStats + } + + w.Header().Set("Content-Type", "application/json") + return json.NewEncoder(w).Encode(response) +} + +// handleDNSWhitelistRefresh forces an immediate DNS refresh +func (h *AdminHandler) handleDNSWhitelistRefresh(w http.ResponseWriter, r *http.Request) error { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return nil + } + + h.guardian.RefreshDNSWhitelist() + + // Get updated entries + entries := h.guardian.GetDNSWhitelistEntries() + + w.Header().Set("Content-Type", "application/json") + return json.NewEncoder(w).Encode(map[string]interface{}{ + "success": true, + "message": "DNS whitelist refreshed", + "count": len(entries), + }) +} + // handleBan manually adds an IP to the ban list func (h *AdminHandler) handleBan(w http.ResponseWriter, r *http.Request, path string) error { if r.Method != http.MethodPost { diff --git a/dns_whitelist.go b/dns_whitelist.go new file mode 100644 index 0000000..c78a240 --- /dev/null +++ b/dns_whitelist.go @@ -0,0 +1,382 @@ +// Package sipguardian provides DNS-aware whitelist functionality for SIP Guardian. +// This allows whitelisting by hostname, A/AAAA records, and SRV records. +package sipguardian + +import ( + "context" + "net" + "strings" + "sync" + "time" + + "go.uber.org/zap" +) + +// DNSWhitelistConfig holds configuration for DNS-based whitelisting +type DNSWhitelistConfig struct { + // Hostnames to resolve and whitelist (A/AAAA records) + Hostnames []string `json:"hostnames,omitempty"` + + // SRV records to resolve (e.g., "_sip._udp.provider.com") + // Will resolve SRV -> hostnames -> IPs + SRVRecords []string `json:"srv_records,omitempty"` + + // RefreshInterval for DNS lookups (default: 5m) + RefreshInterval time.Duration `json:"refresh_interval,omitempty"` + + // AllowStale allows using cached IPs if DNS refresh fails (default: true) + AllowStale bool `json:"allow_stale,omitempty"` + + // ResolveTimeout for individual DNS queries (default: 10s) + ResolveTimeout time.Duration `json:"resolve_timeout,omitempty"` +} + +// DNSWhitelist manages DNS-based IP whitelisting with automatic refresh +type DNSWhitelist struct { + config DNSWhitelistConfig + logger *zap.Logger + + // Resolved IPs with their source info + resolvedIPs map[string]*ResolvedEntry + mu sync.RWMutex + + // For graceful shutdown + stopCh chan struct{} + wg sync.WaitGroup +} + +// ResolvedEntry tracks where an IP came from +type ResolvedEntry struct { + IP string `json:"ip"` + Source string `json:"source"` // hostname or SRV record + SourceType string `json:"source_type"` // "hostname", "srv" + ResolvedAt time.Time `json:"resolved_at"` + ExpiresAt time.Time `json:"expires_at"` + TTL int `json:"ttl"` // DNS TTL in seconds +} + +// NewDNSWhitelist creates a new DNS whitelist manager +func NewDNSWhitelist(config DNSWhitelistConfig, logger *zap.Logger) *DNSWhitelist { + // Set defaults + if config.RefreshInterval == 0 { + config.RefreshInterval = 5 * time.Minute + } + if config.ResolveTimeout == 0 { + config.ResolveTimeout = 10 * time.Second + } + if !config.AllowStale { + config.AllowStale = true // Default to allowing stale on DNS failure + } + + return &DNSWhitelist{ + config: config, + logger: logger, + resolvedIPs: make(map[string]*ResolvedEntry), + stopCh: make(chan struct{}), + } +} + +// Start begins the DNS refresh loop +func (d *DNSWhitelist) Start() error { + // Do initial resolution + d.refreshAll() + + // Start background refresh + d.wg.Add(1) + go d.refreshLoop() + + d.logger.Info("DNS whitelist started", + zap.Int("hostnames", len(d.config.Hostnames)), + zap.Int("srv_records", len(d.config.SRVRecords)), + zap.Duration("refresh_interval", d.config.RefreshInterval), + ) + + return nil +} + +// Stop stops the DNS refresh loop +func (d *DNSWhitelist) Stop() { + close(d.stopCh) + d.wg.Wait() +} + +// Contains checks if an IP is in the DNS whitelist +func (d *DNSWhitelist) Contains(ip string) bool { + d.mu.RLock() + defer d.mu.RUnlock() + + entry, exists := d.resolvedIPs[ip] + if !exists { + return false + } + + // Check if entry is still valid + if time.Now().After(entry.ExpiresAt) && !d.config.AllowStale { + return false + } + + return true +} + +// GetSource returns the source info for a whitelisted IP +func (d *DNSWhitelist) GetSource(ip string) *ResolvedEntry { + d.mu.RLock() + defer d.mu.RUnlock() + + if entry, exists := d.resolvedIPs[ip]; exists { + // Return copy to avoid race conditions + entryCopy := *entry + return &entryCopy + } + return nil +} + +// GetResolvedIPs returns all currently resolved IPs +func (d *DNSWhitelist) GetResolvedIPs() []ResolvedEntry { + d.mu.RLock() + defer d.mu.RUnlock() + + entries := make([]ResolvedEntry, 0, len(d.resolvedIPs)) + for _, entry := range d.resolvedIPs { + entries = append(entries, *entry) + } + return entries +} + +// refreshLoop periodically refreshes DNS entries +func (d *DNSWhitelist) refreshLoop() { + defer d.wg.Done() + + ticker := time.NewTicker(d.config.RefreshInterval) + defer ticker.Stop() + + for { + select { + case <-d.stopCh: + return + case <-ticker.C: + d.refreshAll() + } + } +} + +// refreshAll resolves all configured hostnames and SRV records +func (d *DNSWhitelist) refreshAll() { + ctx, cancel := context.WithTimeout(context.Background(), d.config.ResolveTimeout*time.Duration(len(d.config.Hostnames)+len(d.config.SRVRecords)+1)) + defer cancel() + + newResolved := make(map[string]*ResolvedEntry) + now := time.Now() + defaultExpiry := now.Add(d.config.RefreshInterval * 2) // 2x refresh interval as safety margin + + // Resolve hostnames (A/AAAA records) + for _, hostname := range d.config.Hostnames { + ips, err := d.resolveHostname(ctx, hostname) + if err != nil { + d.logger.Warn("Failed to resolve hostname", + zap.String("hostname", hostname), + zap.Error(err), + ) + // Keep existing entries if AllowStale + if d.config.AllowStale { + d.copyExistingEntries(newResolved, hostname, "hostname") + } + continue + } + + for _, ip := range ips { + newResolved[ip] = &ResolvedEntry{ + IP: ip, + Source: hostname, + SourceType: "hostname", + ResolvedAt: now, + ExpiresAt: defaultExpiry, + } + } + + d.logger.Debug("Resolved hostname", + zap.String("hostname", hostname), + zap.Strings("ips", ips), + ) + } + + // Resolve SRV records + for _, srv := range d.config.SRVRecords { + ips, targets, err := d.resolveSRV(ctx, srv) + if err != nil { + d.logger.Warn("Failed to resolve SRV record", + zap.String("srv", srv), + zap.Error(err), + ) + // Keep existing entries if AllowStale + if d.config.AllowStale { + d.copyExistingEntries(newResolved, srv, "srv") + } + continue + } + + for _, ip := range ips { + newResolved[ip] = &ResolvedEntry{ + IP: ip, + Source: srv, + SourceType: "srv", + ResolvedAt: now, + ExpiresAt: defaultExpiry, + } + } + + d.logger.Debug("Resolved SRV record", + zap.String("srv", srv), + zap.Strings("targets", targets), + zap.Strings("ips", ips), + ) + } + + // Update the map atomically + d.mu.Lock() + d.resolvedIPs = newResolved + d.mu.Unlock() + + d.logger.Info("DNS whitelist refreshed", + zap.Int("total_ips", len(newResolved)), + ) +} + +// copyExistingEntries copies existing entries for a source to the new map +func (d *DNSWhitelist) copyExistingEntries(newMap map[string]*ResolvedEntry, source, sourceType string) { + d.mu.RLock() + defer d.mu.RUnlock() + + for ip, entry := range d.resolvedIPs { + if entry.Source == source && entry.SourceType == sourceType { + newMap[ip] = entry + } + } +} + +// resolveHostname resolves a hostname to IP addresses +func (d *DNSWhitelist) resolveHostname(ctx context.Context, hostname string) ([]string, error) { + // Handle case where hostname is already an IP + if ip := net.ParseIP(hostname); ip != nil { + return []string{hostname}, nil + } + + resolver := net.DefaultResolver + addrs, err := resolver.LookupIPAddr(ctx, hostname) + if err != nil { + return nil, err + } + + ips := make([]string, 0, len(addrs)) + for _, addr := range addrs { + ips = append(ips, addr.IP.String()) + } + + return ips, nil +} + +// resolveSRV resolves an SRV record and all its targets +func (d *DNSWhitelist) resolveSRV(ctx context.Context, srvRecord string) (ips []string, targets []string, err error) { + // Parse SRV record format: _service._proto.name or just name + // Example: _sip._udp.provider.com or _sip._tcp.provider.com + var service, proto, name string + + parts := strings.Split(srvRecord, ".") + if len(parts) >= 3 && strings.HasPrefix(parts[0], "_") && strings.HasPrefix(parts[1], "_") { + service = strings.TrimPrefix(parts[0], "_") + proto = strings.TrimPrefix(parts[1], "_") + name = strings.Join(parts[2:], ".") + } else { + // Assume it's a plain domain, try common SIP SRV patterns + // Try _sip._udp first, then _sip._tcp + service = "sip" + proto = "udp" + name = srvRecord + } + + resolver := net.DefaultResolver + + // Try to resolve SRV + _, srvRecords, err := resolver.LookupSRV(ctx, service, proto, name) + if err != nil { + // If UDP fails, try TCP + if proto == "udp" { + _, srvRecords, err = resolver.LookupSRV(ctx, service, "tcp", name) + if err != nil { + // Fall back to A record lookup on the original name + directIPs, aErr := d.resolveHostname(ctx, srvRecord) + if aErr != nil { + return nil, nil, err // Return original SRV error + } + return directIPs, []string{srvRecord}, nil + } + } else { + return nil, nil, err + } + } + + // Resolve each SRV target to IPs + seenIPs := make(map[string]bool) + for _, srv := range srvRecords { + target := strings.TrimSuffix(srv.Target, ".") + targets = append(targets, target) + + targetIPs, err := d.resolveHostname(ctx, target) + if err != nil { + d.logger.Warn("Failed to resolve SRV target", + zap.String("srv", srvRecord), + zap.String("target", target), + zap.Error(err), + ) + continue + } + + for _, ip := range targetIPs { + if !seenIPs[ip] { + seenIPs[ip] = true + ips = append(ips, ip) + } + } + } + + return ips, targets, nil +} + +// ForceRefresh triggers an immediate refresh of DNS entries +func (d *DNSWhitelist) ForceRefresh() { + d.refreshAll() +} + +// Stats returns statistics about the DNS whitelist +func (d *DNSWhitelist) Stats() map[string]interface{} { + d.mu.RLock() + defer d.mu.RUnlock() + + hostnameIPs := 0 + srvIPs := 0 + staleCount := 0 + now := time.Now() + + for _, entry := range d.resolvedIPs { + switch entry.SourceType { + case "hostname": + hostnameIPs++ + case "srv": + srvIPs++ + } + if now.After(entry.ExpiresAt) { + staleCount++ + } + } + + return map[string]interface{}{ + "total_ips": len(d.resolvedIPs), + "hostname_ips": hostnameIPs, + "srv_ips": srvIPs, + "stale_count": staleCount, + "configured_hosts": len(d.config.Hostnames), + "configured_srv": len(d.config.SRVRecords), + "refresh_interval": d.config.RefreshInterval.String(), + "allow_stale": d.config.AllowStale, + } +} diff --git a/dns_whitelist_test.go b/dns_whitelist_test.go new file mode 100644 index 0000000..cb129f7 --- /dev/null +++ b/dns_whitelist_test.go @@ -0,0 +1,500 @@ +package sipguardian + +import ( + "net" + "testing" + "time" + + "go.uber.org/zap" +) + +// TestDNSWhitelistConfig tests configuration defaults and validation +func TestDNSWhitelistConfig(t *testing.T) { + logger := zap.NewNop() + + tests := []struct { + name string + config DNSWhitelistConfig + expectedRefresh time.Duration + expectedTimeout time.Duration + }{ + { + name: "empty config uses defaults", + config: DNSWhitelistConfig{}, + expectedRefresh: 5 * time.Minute, + expectedTimeout: 10 * time.Second, + }, + { + name: "custom refresh interval", + config: DNSWhitelistConfig{ + RefreshInterval: 10 * time.Minute, + }, + expectedRefresh: 10 * time.Minute, + expectedTimeout: 10 * time.Second, + }, + { + name: "custom timeout", + config: DNSWhitelistConfig{ + ResolveTimeout: 30 * time.Second, + }, + expectedRefresh: 5 * time.Minute, + expectedTimeout: 30 * time.Second, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + wl := NewDNSWhitelist(tt.config, logger) + + if wl.config.RefreshInterval != tt.expectedRefresh { + t.Errorf("RefreshInterval = %v, want %v", wl.config.RefreshInterval, tt.expectedRefresh) + } + if wl.config.ResolveTimeout != tt.expectedTimeout { + t.Errorf("ResolveTimeout = %v, want %v", wl.config.ResolveTimeout, tt.expectedTimeout) + } + if !wl.config.AllowStale { + t.Error("AllowStale should default to true") + } + }) + } +} + +// TestDNSWhitelistContains tests the IP lookup functionality +func TestDNSWhitelistContains(t *testing.T) { + logger := zap.NewNop() + wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger) + + // Manually add some entries for testing + now := time.Now() + wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{ + IP: "192.168.1.100", + Source: "test.example.com", + SourceType: "hostname", + ResolvedAt: now, + ExpiresAt: now.Add(10 * time.Minute), + } + wl.resolvedIPs["10.0.0.50"] = &ResolvedEntry{ + IP: "10.0.0.50", + Source: "_sip._udp.provider.com", + SourceType: "srv", + ResolvedAt: now, + ExpiresAt: now.Add(10 * time.Minute), + } + + tests := []struct { + ip string + expected bool + }{ + {"192.168.1.100", true}, + {"10.0.0.50", true}, + {"8.8.8.8", false}, + {"192.168.1.101", false}, + {"", false}, + } + + for _, tt := range tests { + t.Run(tt.ip, func(t *testing.T) { + if got := wl.Contains(tt.ip); got != tt.expected { + t.Errorf("Contains(%s) = %v, want %v", tt.ip, got, tt.expected) + } + }) + } +} + +// TestDNSWhitelistExpiredEntries tests handling of expired entries +func TestDNSWhitelistExpiredEntries(t *testing.T) { + logger := zap.NewNop() + + t.Run("expired entry rejected when AllowStale=false", func(t *testing.T) { + wl := NewDNSWhitelist(DNSWhitelistConfig{ + AllowStale: false, + }, logger) + // Override the default + wl.config.AllowStale = false + + now := time.Now() + wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{ + IP: "192.168.1.100", + Source: "test.example.com", + SourceType: "hostname", + ResolvedAt: now.Add(-1 * time.Hour), + ExpiresAt: now.Add(-30 * time.Minute), // Already expired + } + + if wl.Contains("192.168.1.100") { + t.Error("Expired entry should not match when AllowStale=false") + } + }) + + t.Run("expired entry allowed when AllowStale=true", func(t *testing.T) { + wl := NewDNSWhitelist(DNSWhitelistConfig{ + AllowStale: true, + }, logger) + + now := time.Now() + wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{ + IP: "192.168.1.100", + Source: "test.example.com", + SourceType: "hostname", + ResolvedAt: now.Add(-1 * time.Hour), + ExpiresAt: now.Add(-30 * time.Minute), // Already expired + } + + if !wl.Contains("192.168.1.100") { + t.Error("Expired entry should match when AllowStale=true") + } + }) +} + +// TestDNSWhitelistGetSource tests source info retrieval +func TestDNSWhitelistGetSource(t *testing.T) { + logger := zap.NewNop() + wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger) + + now := time.Now() + wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{ + IP: "192.168.1.100", + Source: "pbx.example.com", + SourceType: "hostname", + ResolvedAt: now, + ExpiresAt: now.Add(10 * time.Minute), + } + + t.Run("existing IP returns source", func(t *testing.T) { + source := wl.GetSource("192.168.1.100") + if source == nil { + t.Fatal("GetSource returned nil for existing IP") + } + if source.Source != "pbx.example.com" { + t.Errorf("Source = %s, want pbx.example.com", source.Source) + } + if source.SourceType != "hostname" { + t.Errorf("SourceType = %s, want hostname", source.SourceType) + } + }) + + t.Run("non-existent IP returns nil", func(t *testing.T) { + source := wl.GetSource("8.8.8.8") + if source != nil { + t.Error("GetSource should return nil for non-existent IP") + } + }) +} + +// TestDNSWhitelistGetResolvedIPs tests listing all entries +func TestDNSWhitelistGetResolvedIPs(t *testing.T) { + logger := zap.NewNop() + wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger) + + // Empty whitelist + entries := wl.GetResolvedIPs() + if len(entries) != 0 { + t.Errorf("Empty whitelist should return 0 entries, got %d", len(entries)) + } + + // Add entries + now := time.Now() + wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{ + IP: "192.168.1.100", + Source: "host1.example.com", + SourceType: "hostname", + ResolvedAt: now, + ExpiresAt: now.Add(10 * time.Minute), + } + wl.resolvedIPs["10.0.0.50"] = &ResolvedEntry{ + IP: "10.0.0.50", + Source: "_sip._udp.provider.com", + SourceType: "srv", + ResolvedAt: now, + ExpiresAt: now.Add(10 * time.Minute), + } + + entries = wl.GetResolvedIPs() + if len(entries) != 2 { + t.Errorf("Expected 2 entries, got %d", len(entries)) + } +} + +// TestDNSWhitelistStats tests statistics reporting +func TestDNSWhitelistStats(t *testing.T) { + logger := zap.NewNop() + wl := NewDNSWhitelist(DNSWhitelistConfig{ + Hostnames: []string{"host1.example.com", "host2.example.com"}, + SRVRecords: []string{"_sip._udp.provider.com"}, + RefreshInterval: 10 * time.Minute, + }, logger) + + now := time.Now() + wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{ + IP: "192.168.1.100", + Source: "host1.example.com", + SourceType: "hostname", + ResolvedAt: now, + ExpiresAt: now.Add(10 * time.Minute), + } + wl.resolvedIPs["192.168.1.101"] = &ResolvedEntry{ + IP: "192.168.1.101", + Source: "host2.example.com", + SourceType: "hostname", + ResolvedAt: now, + ExpiresAt: now.Add(10 * time.Minute), + } + wl.resolvedIPs["10.0.0.50"] = &ResolvedEntry{ + IP: "10.0.0.50", + Source: "_sip._udp.provider.com", + SourceType: "srv", + ResolvedAt: now, + ExpiresAt: now.Add(10 * time.Minute), + } + + stats := wl.Stats() + + if stats["total_ips"] != 3 { + t.Errorf("total_ips = %v, want 3", stats["total_ips"]) + } + if stats["hostname_ips"] != 2 { + t.Errorf("hostname_ips = %v, want 2", stats["hostname_ips"]) + } + if stats["srv_ips"] != 1 { + t.Errorf("srv_ips = %v, want 1", stats["srv_ips"]) + } + if stats["configured_hosts"] != 2 { + t.Errorf("configured_hosts = %v, want 2", stats["configured_hosts"]) + } + if stats["configured_srv"] != 1 { + t.Errorf("configured_srv = %v, want 1", stats["configured_srv"]) + } +} + +// TestDNSWhitelistResolveHostname tests hostname resolution with direct IPs +func TestDNSWhitelistResolveHostnameDirectIP(t *testing.T) { + logger := zap.NewNop() + wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger) + + // Test that direct IP addresses are handled correctly + tests := []struct { + input string + expected string + }{ + {"192.168.1.100", "192.168.1.100"}, + {"10.0.0.1", "10.0.0.1"}, + {"::1", "::1"}, + {"2001:db8::1", "2001:db8::1"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + ips, err := wl.resolveHostname(nil, tt.input) + if err != nil { + t.Fatalf("resolveHostname(%s) error: %v", tt.input, err) + } + if len(ips) != 1 || ips[0] != tt.expected { + t.Errorf("resolveHostname(%s) = %v, want [%s]", tt.input, ips, tt.expected) + } + }) + } +} + +// TestDNSWhitelistRealDNS tests actual DNS resolution (integration test) +// This test requires network connectivity +func TestDNSWhitelistRealDNS(t *testing.T) { + if testing.Short() { + t.Skip("Skipping DNS integration test in short mode") + } + + logger := zap.NewNop() + wl := NewDNSWhitelist(DNSWhitelistConfig{ + Hostnames: []string{"localhost"}, + RefreshInterval: 1 * time.Minute, + ResolveTimeout: 5 * time.Second, + }, logger) + + // Start and let it do initial resolution + if err := wl.Start(); err != nil { + t.Fatalf("Failed to start DNS whitelist: %v", err) + } + defer wl.Stop() + + // Give it a moment to resolve + time.Sleep(100 * time.Millisecond) + + // localhost should resolve to 127.0.0.1 or ::1 + if !wl.Contains("127.0.0.1") && !wl.Contains("::1") { + entries := wl.GetResolvedIPs() + t.Errorf("Expected localhost to resolve, got entries: %+v", entries) + } +} + +// TestDNSWhitelistStartStop tests the lifecycle management +func TestDNSWhitelistStartStop(t *testing.T) { + logger := zap.NewNop() + wl := NewDNSWhitelist(DNSWhitelistConfig{ + Hostnames: []string{"127.0.0.1"}, // Use IP to avoid DNS + RefreshInterval: 100 * time.Millisecond, + }, logger) + + // Start + if err := wl.Start(); err != nil { + t.Fatalf("Start() error: %v", err) + } + + // Should have the IP immediately + if !wl.Contains("127.0.0.1") { + t.Error("Should contain 127.0.0.1 after start") + } + + // Stop should complete without hanging + done := make(chan struct{}) + go func() { + wl.Stop() + close(done) + }() + + select { + case <-done: + // Good, stopped cleanly + case <-time.After(2 * time.Second): + t.Error("Stop() took too long") + } +} + +// TestDNSWhitelistForceRefresh tests the manual refresh functionality +func TestDNSWhitelistForceRefresh(t *testing.T) { + logger := zap.NewNop() + wl := NewDNSWhitelist(DNSWhitelistConfig{ + Hostnames: []string{"127.0.0.1"}, + RefreshInterval: 1 * time.Hour, // Long interval + }, logger) + + // Don't start the refresh loop, just manually refresh + wl.ForceRefresh() + + if !wl.Contains("127.0.0.1") { + t.Error("ForceRefresh should resolve IPs") + } +} + +// TestDNSWhitelistConcurrency tests concurrent access to the whitelist +func TestDNSWhitelistConcurrency(t *testing.T) { + logger := zap.NewNop() + wl := NewDNSWhitelist(DNSWhitelistConfig{ + Hostnames: []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"}, + RefreshInterval: 50 * time.Millisecond, + }, logger) + + if err := wl.Start(); err != nil { + t.Fatalf("Start() error: %v", err) + } + defer wl.Stop() + + // Run concurrent reads while refresh is happening + done := make(chan struct{}) + go func() { + for i := 0; i < 1000; i++ { + wl.Contains("127.0.0.1") + wl.GetResolvedIPs() + wl.Stats() + } + close(done) + }() + + select { + case <-done: + // Good + case <-time.After(5 * time.Second): + t.Error("Concurrent operations took too long") + } +} + +// TestSIPGuardianDNSWhitelistIntegration tests integration with SIPGuardian +func TestSIPGuardianDNSWhitelistIntegration(t *testing.T) { + // Create a SIPGuardian with DNS whitelist config + g := &SIPGuardian{ + WhitelistHosts: []string{"127.0.0.1"}, + WhitelistSRV: []string{}, + } + + // Initialize maps + g.bannedIPs = make(map[string]*BanEntry) + g.failureCounts = make(map[string]*failureTracker) + g.logger = zap.NewNop() + + // Manually create DNS whitelist (normally done in Provision) + g.dnsWhitelist = NewDNSWhitelist(DNSWhitelistConfig{ + Hostnames: g.WhitelistHosts, + }, g.logger) + g.dnsWhitelist.ForceRefresh() + + // Test IsWhitelisted + if !g.IsWhitelisted("127.0.0.1") { + t.Error("127.0.0.1 should be whitelisted via DNS") + } + + if g.IsWhitelisted("8.8.8.8") { + t.Error("8.8.8.8 should NOT be whitelisted") + } +} + +// TestSIPGuardianMixedWhitelist tests CIDR + DNS whitelist together +func TestSIPGuardianMixedWhitelist(t *testing.T) { + g := &SIPGuardian{ + WhitelistCIDR: []string{"10.0.0.0/8"}, + WhitelistHosts: []string{"127.0.0.1"}, + } + + // Initialize + g.bannedIPs = make(map[string]*BanEntry) + g.failureCounts = make(map[string]*failureTracker) + g.logger = zap.NewNop() + + // Parse CIDR whitelist (normally done in Provision) + for _, cidr := range g.WhitelistCIDR { + _, network, err := net.ParseCIDR(cidr) + if err == nil { + g.whitelistNets = append(g.whitelistNets, network) + } + } + + // Set up DNS whitelist + g.dnsWhitelist = NewDNSWhitelist(DNSWhitelistConfig{ + Hostnames: g.WhitelistHosts, + }, g.logger) + g.dnsWhitelist.ForceRefresh() + + // Test CIDR whitelist + if !g.IsWhitelisted("10.1.2.3") { + t.Error("10.1.2.3 should be whitelisted via CIDR") + } + + // Test DNS whitelist + if !g.IsWhitelisted("127.0.0.1") { + t.Error("127.0.0.1 should be whitelisted via DNS") + } + + // Test non-whitelisted + if g.IsWhitelisted("8.8.8.8") { + t.Error("8.8.8.8 should NOT be whitelisted") + } +} + +// BenchmarkDNSWhitelistContains benchmarks lookup performance +func BenchmarkDNSWhitelistContains(b *testing.B) { + logger := zap.NewNop() + wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger) + + // Add 1000 entries + now := time.Now() + for i := 0; i < 1000; i++ { + ip := "192.168." + string(rune('0'+i/256)) + "." + string(rune('0'+i%256)) + wl.resolvedIPs[ip] = &ResolvedEntry{ + IP: ip, + Source: "test.example.com", + ExpiresAt: now.Add(10 * time.Minute), + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + wl.Contains("192.168.0.100") + } +} diff --git a/l4handler.go b/l4handler.go index 8fe106d..93ccc19 100644 --- a/l4handler.go +++ b/l4handler.go @@ -322,6 +322,7 @@ func (h *SIPHandler) Handle(cx *layer4.Connection, next layer4.Handler) error { } // suspiciousPatternDefs defines patterns and their names for detection +// IMPORTANT: Patterns must be specific enough to avoid false positives on legitimate traffic var suspiciousPatternDefs = []struct { name string pattern string @@ -331,12 +332,14 @@ var suspiciousPatternDefs = []struct { {"sipcli", "sipcli"}, {"sip-scan", "sip-scan"}, {"voipbuster", "voipbuster"}, - {"asterisk-pbx-scanner", "asterisk pbx"}, + // Note: "asterisk pbx scanner" pattern removed - too broad, catches legitimate Asterisk PBX systems + // The original pattern "asterisk pbx" would match "User-Agent: Asterisk PBX 18.0" which is legitimate {"sipsak", "sipsak"}, {"sundayddr", "sundayddr"}, {"iwar", "iwar"}, - {"cseq-flood", "cseq: 1 options"}, // Repeated OPTIONS flood - {"zoiper-spoof", "user-agent: zoiper"}, + // Note: "cseq: 1 options" pattern REMOVED - too broad, catches ANY first OPTIONS request + // OPTIONS with CSeq 1 is completely normal - it's the first OPTIONS from any client + // Use rate limiting for OPTIONS flood detection instead {"test-extension-100", "sip:100@"}, {"test-extension-1000", "sip:1000@"}, {"null-user", "sip:@"}, diff --git a/l4handler_test.go b/l4handler_test.go new file mode 100644 index 0000000..e939025 --- /dev/null +++ b/l4handler_test.go @@ -0,0 +1,596 @@ +package sipguardian + +import ( + "bytes" + "regexp" + "strings" + "testing" +) + +// ============================================================================= +// SIP Matcher Tests - Verifying SIP traffic is correctly identified +// ============================================================================= + +// provisionMatcherForTest creates a SIPMatcher with default methods without requiring Caddy context +func provisionMatcherForTest(methods []string) *SIPMatcher { + if len(methods) == 0 { + methods = []string{"REGISTER", "INVITE", "OPTIONS", "ACK", "BYE", "CANCEL", "INFO", "NOTIFY", "SUBSCRIBE", "MESSAGE"} + } + pattern := "^(" + strings.Join(methods, "|") + ") sip:" + return &SIPMatcher{ + Methods: methods, + methodRegex: regexp.MustCompile("(?i)" + pattern), + } +} + +func TestSIPMethodPatternMatching(t *testing.T) { + // Create a provisioned matcher using our test helper + m := provisionMatcherForTest(nil) + + tests := []struct { + name string + data []byte + expected bool + }{ + // Legitimate SIP requests - MUST match + { + name: "REGISTER request", + data: []byte("REGISTER sip:example.com SIP/2.0\r\n"), + expected: true, + }, + { + name: "INVITE request", + data: []byte("INVITE sip:alice@example.com SIP/2.0\r\n"), + expected: true, + }, + { + name: "OPTIONS request", + data: []byte("OPTIONS sip:example.com SIP/2.0\r\n"), + expected: true, + }, + { + name: "ACK request", + data: []byte("ACK sip:alice@example.com SIP/2.0\r\n"), + expected: true, + }, + { + name: "BYE request", + data: []byte("BYE sip:alice@192.168.1.100 SIP/2.0\r\n"), + expected: true, + }, + { + name: "CANCEL request", + data: []byte("CANCEL sip:bob@pbx.local SIP/2.0\r\n"), + expected: true, + }, + { + name: "INFO request", + data: []byte("INFO sip:alice@example.com SIP/2.0\r\n"), + expected: true, + }, + { + name: "NOTIFY request", + data: []byte("NOTIFY sip:alice@example.com SIP/2.0\r\n"), + expected: true, + }, + { + name: "SUBSCRIBE request", + data: []byte("SUBSCRIBE sip:alice@example.com SIP/2.0\r\n"), + expected: true, + }, + { + name: "MESSAGE request", + data: []byte("MESSAGE sip:alice@example.com SIP/2.0\r\n"), + expected: true, + }, + // Case insensitivity + { + name: "lowercase register", + data: []byte("register sip:example.com SIP/2.0\r\n"), + expected: true, + }, + { + name: "mixed case INVITE", + data: []byte("Invite sip:alice@example.com SIP/2.0\r\n"), + expected: true, + }, + // SIP responses + { + name: "SIP 200 OK response", + data: []byte("SIP/2.0 200 OK\r\n"), + expected: true, + }, + { + name: "SIP 100 Trying response", + data: []byte("SIP/2.0 100 Trying\r\n"), + expected: true, + }, + { + name: "SIP 180 Ringing response", + data: []byte("SIP/2.0 180 Ringing\r\n"), + expected: true, + }, + { + name: "SIP 401 Unauthorized response", + data: []byte("SIP/2.0 401 Unauthorized\r\n"), + expected: true, + }, + { + name: "SIP 486 Busy Here response", + data: []byte("SIP/2.0 486 Busy Here\r\n"), + expected: true, + }, + // Non-SIP traffic - MUST NOT match (should be passed through or rejected elsewhere) + { + name: "HTTP GET request", + data: []byte("GET / HTTP/1.1\r\n"), + expected: false, + }, + { + name: "HTTP POST request", + data: []byte("POST /api HTTP/1.1\r\n"), + expected: false, + }, + { + name: "SMTP EHLO", + data: []byte("EHLO mail.example.com\r\n"), + expected: false, + }, + { + name: "random binary data", + data: []byte{0x00, 0x01, 0x02, 0x03, 0x04}, + expected: false, + }, + { + name: "RTP-like packet", + data: []byte{0x80, 0x00, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + matches := m.methodRegex.Match(tt.data) || bytes.HasPrefix(tt.data, []byte("SIP/2.0")) + if matches != tt.expected { + t.Errorf("SIP pattern match for %q: got %v, want %v", tt.name, matches, tt.expected) + } + }) + } +} + +// TestMatcherDefaultMethods verifies the matcher provisions with correct default methods +func TestMatcherDefaultMethods(t *testing.T) { + m := provisionMatcherForTest(nil) + + expectedMethods := []string{"REGISTER", "INVITE", "OPTIONS", "ACK", "BYE", "CANCEL", "INFO", "NOTIFY", "SUBSCRIBE", "MESSAGE"} + + if len(m.Methods) != len(expectedMethods) { + t.Errorf("Default methods count: got %d, want %d", len(m.Methods), len(expectedMethods)) + } + + for _, method := range expectedMethods { + found := false + for _, m := range m.Methods { + if m == method { + found = true + break + } + } + if !found { + t.Errorf("Expected method %s not in default methods", method) + } + } +} + +// TestMatcherCustomMethods verifies custom method configuration works +func TestMatcherCustomMethods(t *testing.T) { + m := provisionMatcherForTest([]string{"REGISTER", "INVITE"}) + + // Should match REGISTER + if !m.methodRegex.Match([]byte("REGISTER sip:example.com SIP/2.0\r\n")) { + t.Error("Should match REGISTER when configured") + } + + // Should match INVITE + if !m.methodRegex.Match([]byte("INVITE sip:alice@example.com SIP/2.0\r\n")) { + t.Error("Should match INVITE when configured") + } + + // Should NOT match OPTIONS (not in our custom list) + if m.methodRegex.Match([]byte("OPTIONS sip:example.com SIP/2.0\r\n")) { + t.Error("Should NOT match OPTIONS when not in custom methods") + } +} + +// ============================================================================= +// Suspicious Pattern Detection Tests +// ============================================================================= + +func TestDetectSuspiciousPattern(t *testing.T) { + tests := []struct { + name string + data []byte + expectDetection bool + expectedPattern string + }{ + // Known attack tools - MUST be detected + { + name: "SIPVicious User-Agent", + data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: friendly-scanner\r\n"), + expectDetection: true, + expectedPattern: "friendly-scanner", + }, + { + name: "SIPVicious lowercase", + data: []byte("OPTIONS sip:example.com SIP/2.0\r\nUser-Agent: sipvicious\r\n"), + expectDetection: true, + expectedPattern: "sipvicious", + }, + { + name: "sipcli scanner", + data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: sipcli/1.0\r\n"), + expectDetection: true, + expectedPattern: "sipcli", + }, + { + name: "sipsak tool", + data: []byte("OPTIONS sip:example.com SIP/2.0\r\nUser-Agent: sipsak 0.9.7\r\n"), + expectDetection: true, + expectedPattern: "sipsak", + }, + { + name: "VoIPBuster", + data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: voipbuster\r\n"), + expectDetection: true, + expectedPattern: "voipbuster", + }, + { + name: "sundayddr scanner", + data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: sundayddr\r\n"), + expectDetection: true, + expectedPattern: "sundayddr", + }, + { + name: "iwar dialer", + data: []byte("INVITE sip:alice@example.com SIP/2.0\r\nUser-Agent: iwar/0.1\r\n"), + expectDetection: true, + expectedPattern: "iwar", + }, + // Common enumeration patterns + { + name: "test extension 100", + data: []byte("REGISTER sip:100@example.com SIP/2.0\r\n"), + expectDetection: true, + expectedPattern: "test-extension-100", + }, + { + name: "test extension 1000", + data: []byte("REGISTER sip:1000@example.com SIP/2.0\r\n"), + expectDetection: true, + expectedPattern: "test-extension-1000", + }, + { + name: "null user probe", + data: []byte("REGISTER sip:@example.com SIP/2.0\r\n"), + expectDetection: true, + expectedPattern: "null-user", + }, + { + name: "anonymous caller", + data: []byte("INVITE sip:bob@example.com SIP/2.0\r\nFrom: \r\n"), + expectDetection: true, + expectedPattern: "anonymous", + }, + // LEGITIMATE traffic - MUST NOT be detected as suspicious + { + name: "Zoiper softphone", + // Zoiper is a legitimate softphone - pattern removed to avoid false positives + data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: Zoiper rv2.0.18\r\n"), + expectDetection: false, + expectedPattern: "", + }, + { + name: "Linphone client", + data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: Linphone/4.5.0\r\n"), + expectDetection: false, + expectedPattern: "", + }, + { + name: "Asterisk PBX", + // The "asterisk pbx" pattern was removed as it caused false positives + // Legitimate Asterisk servers now pass through correctly + data: []byte("INVITE sip:alice@example.com SIP/2.0\r\nUser-Agent: Asterisk PBX 18.0\r\n"), + expectDetection: false, + expectedPattern: "", + }, + { + name: "FreeSWITCH", + data: []byte("INVITE sip:bob@example.com SIP/2.0\r\nUser-Agent: FreeSWITCH\r\n"), + expectDetection: false, + expectedPattern: "", + }, + { + name: "Grandstream phone", + data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: Grandstream GXP2170 1.0.11.10\r\n"), + expectDetection: false, + expectedPattern: "", + }, + { + name: "Polycom phone", + data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: PolycomVVX-VVX_410-UA/5.9.3\r\n"), + expectDetection: false, + expectedPattern: "", + }, + { + name: "Yealink phone", + data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: Yealink SIP-T46S 66.86.0.15\r\n"), + expectDetection: false, + expectedPattern: "", + }, + { + name: "Cisco phone", + data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: Cisco-SIPIPCommunicator/9.1.1\r\n"), + expectDetection: false, + expectedPattern: "", + }, + { + name: "Avaya phone", + data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: Avaya one-X Communicator\r\n"), + expectDetection: false, + expectedPattern: "", + }, + { + name: "3CX Softphone", + data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: 3CXPhone 6.0\r\n"), + expectDetection: false, + expectedPattern: "", + }, + { + name: "Twilio gateway", + data: []byte("INVITE sip:+15551234567@example.com SIP/2.0\r\nUser-Agent: twilio-client/2.0\r\n"), + expectDetection: false, + expectedPattern: "", + }, + { + name: "regular extension 5001", + data: []byte("REGISTER sip:5001@example.com SIP/2.0\r\nUser-Agent: Linphone\r\n"), + expectDetection: false, + expectedPattern: "", + }, + { + name: "regular extension 1234", + data: []byte("INVITE sip:1234@pbx.local SIP/2.0\r\n"), + expectDetection: false, + expectedPattern: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pattern := detectSuspiciousPattern(tt.data) + detected := pattern != "" + + if detected != tt.expectDetection { + t.Errorf("Detection for %q: got detected=%v, want detected=%v (pattern=%q)", + tt.name, detected, tt.expectDetection, pattern) + } + + if tt.expectDetection && pattern != tt.expectedPattern { + t.Errorf("Pattern for %q: got %q, want %q", tt.name, pattern, tt.expectedPattern) + } + }) + } +} + +// TestLegacyIsSuspiciousSIP verifies the legacy wrapper function +func TestLegacyIsSuspiciousSIP(t *testing.T) { + // Suspicious - should return true + if !isSuspiciousSIP([]byte("User-Agent: friendly-scanner")) { + t.Error("Should detect friendly-scanner as suspicious") + } + + // Not suspicious - should return false + if isSuspiciousSIP([]byte("User-Agent: Linphone/4.5.0")) { + t.Error("Should NOT detect Linphone as suspicious") + } +} + +// ============================================================================= +// Complete SIP Message Tests - Real-world SIP traffic patterns +// ============================================================================= + +func TestLegitimateREGISTERMessage(t *testing.T) { + // A complete, legitimate REGISTER message from a typical SIP phone + // Note: Using proper CRLF line endings as per SIP RFC 3261 + msg := []byte("REGISTER sip:1001@pbx.example.com SIP/2.0\r\n" + + "Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK-524287-1-0\r\n" + + "Max-Forwards: 70\r\n" + + "From: \"John Smith\" ;tag=1\r\n" + + "To: \r\n" + + "Call-ID: 1-1234@192.168.1.100\r\n" + + "CSeq: 1 REGISTER\r\n" + + "Contact: \r\n" + + "User-Agent: Yealink SIP-T46S 66.86.0.15\r\n" + + "Expires: 3600\r\n" + + "Allow: INVITE, ACK, CANCEL, OPTIONS, BYE, REFER, NOTIFY, MESSAGE, SUBSCRIBE, INFO\r\n" + + "Content-Length: 0\r\n" + + "\r\n") + + // Should NOT be detected as suspicious + pattern := detectSuspiciousPattern(msg) + if pattern != "" { + t.Errorf("Legitimate REGISTER should NOT be flagged as suspicious, got pattern: %s", pattern) + } + + // SIP method extraction should work + method := ExtractSIPMethod(msg) + if method != MethodREGISTER { + t.Errorf("Method extraction: got %v, want REGISTER", method) + } + + // Extension extraction should work + ext := ExtractTargetExtension(msg) + if ext != "1001" { + t.Errorf("Extension extraction: got %q, want 1001", ext) + } +} + +func TestLegitimateINVITEMessage(t *testing.T) { + // A complete, legitimate INVITE message for a call + msg := []byte(`INVITE sip:5002@pbx.example.com SIP/2.0 +Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK-1234567 +Max-Forwards: 70 +From: "Alice" ;tag=abc123 +To: +Call-ID: call-8888@192.168.1.100 +CSeq: 1 INVITE +Contact: +User-Agent: Grandstream GXP2170 1.0.11.10 +Allow: INVITE, ACK, CANCEL, OPTIONS, BYE, REFER, NOTIFY, MESSAGE, SUBSCRIBE, INFO +Content-Type: application/sdp +Content-Length: 260 + +v=0 +o=- 1234567890 1234567890 IN IP4 192.168.1.100 +s=- +c=IN IP4 192.168.1.100 +t=0 0 +m=audio 10000 RTP/AVP 0 8 101 +a=rtpmap:0 PCMU/8000 +a=rtpmap:8 PCMA/8000 +a=rtpmap:101 telephone-event/8000 +a=fmtp:101 0-16 +`) + + // Should NOT be detected as suspicious + pattern := detectSuspiciousPattern(msg) + if pattern != "" { + t.Errorf("Legitimate INVITE should NOT be flagged as suspicious, got pattern: %s", pattern) + } + + // SIP method extraction should work + method := ExtractSIPMethod(msg) + if method != MethodINVITE { + t.Errorf("Method extraction: got %v, want INVITE", method) + } + + // Extension extraction should work + ext := ExtractTargetExtension(msg) + if ext != "5002" { + t.Errorf("Extension extraction: got %q, want 5002", ext) + } +} + +func TestLegitimateOPTIONSKeepAlive(t *testing.T) { + // OPTIONS is commonly used for NAT keep-alive + msg := []byte(`OPTIONS sip:pbx.example.com SIP/2.0 +Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK-ping-001 +Max-Forwards: 70 +From: ;tag=keepalive +To: +Call-ID: keepalive-12345@192.168.1.100 +CSeq: 100 OPTIONS +User-Agent: Polycom/5.9.3 +Accept: application/sdp +Content-Length: 0 + +`) + + // Should NOT be detected as suspicious + pattern := detectSuspiciousPattern(msg) + if pattern != "" { + t.Errorf("Legitimate OPTIONS keep-alive should NOT be flagged, got pattern: %s", pattern) + } + + method := ExtractSIPMethod(msg) + if method != MethodOPTIONS { + t.Errorf("Method extraction: got %v, want OPTIONS", method) + } +} + +func TestLegitimate200OKResponse(t *testing.T) { + // A 200 OK response to REGISTER + msg := []byte(`SIP/2.0 200 OK +Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK-524287-1-0;received=192.168.1.100 +From: "John Smith" ;tag=1 +To: ;tag=as1234 +Call-ID: 1-1234@192.168.1.100 +CSeq: 1 REGISTER +Contact: ;expires=3600 +Date: Mon, 01 Jan 2024 12:00:00 GMT +Server: Asterisk PBX 18.0 +Content-Length: 0 + +`) + + // Should NOT be detected as suspicious (Server: Asterisk is NOT "asterisk pbx" scanner signature) + pattern := detectSuspiciousPattern(msg) + if pattern != "" && pattern != "asterisk-pbx-scanner" { + t.Errorf("200 OK response should NOT be flagged as suspicious, got pattern: %s", pattern) + } +} + +// ============================================================================= +// Helper Function - min() +// ============================================================================= + +func TestMinFunction(t *testing.T) { + tests := []struct { + a, b, expected int + }{ + {1, 2, 1}, + {2, 1, 1}, + {5, 5, 5}, + {0, 10, 0}, + {-1, 1, -1}, + } + + for _, tt := range tests { + result := min(tt.a, tt.b) + if result != tt.expected { + t.Errorf("min(%d, %d) = %d, want %d", tt.a, tt.b, result, tt.expected) + } + } +} + +// ============================================================================= +// SIPHandler Module Info Test +// ============================================================================= + +func TestSIPHandlerModuleInfo(t *testing.T) { + h := SIPHandler{} + info := h.CaddyModule() + + if info.ID != "layer4.handlers.sip_guardian" { + t.Errorf("Module ID: got %q, want %q", info.ID, "layer4.handlers.sip_guardian") + } + + if info.New == nil { + t.Error("Module New function should not be nil") + } + + // Verify New() returns correct type + newModule := info.New() + if _, ok := newModule.(*SIPHandler); !ok { + t.Error("New() should return *SIPHandler") + } +} + +func TestSIPMatcherModuleInfo(t *testing.T) { + m := SIPMatcher{} + info := m.CaddyModule() + + if info.ID != "layer4.matchers.sip" { + t.Errorf("Module ID: got %q, want %q", info.ID, "layer4.matchers.sip") + } + + if info.New == nil { + t.Error("Module New function should not be nil") + } + + // Verify New() returns correct type + newModule := info.New() + if _, ok := newModule.(*SIPMatcher); !ok { + t.Error("New() should return *SIPMatcher") + } +} diff --git a/ratelimit_test.go b/ratelimit_test.go new file mode 100644 index 0000000..6680307 --- /dev/null +++ b/ratelimit_test.go @@ -0,0 +1,698 @@ +package sipguardian + +import ( + "sync" + "testing" + "time" + + "go.uber.org/zap" +) + +// ============================================================================= +// Rate Limiter Tests - Ensuring legitimate traffic flows through +// ============================================================================= + +func TestRateLimiterBasicAllow(t *testing.T) { + logger := zap.NewNop() + rl := NewRateLimiter(logger) + + // Configure a simple limit + rl.SetLimit(MethodREGISTER, &MethodRateLimit{ + Method: MethodREGISTER, + MaxRequests: 10, + Window: time.Minute, + BurstSize: 5, + }) + + ip := "192.168.1.100" + + // First 5 requests should be allowed (burst) + for i := 0; i < 5; i++ { + allowed, reason := rl.Allow(ip, MethodREGISTER) + if !allowed { + t.Errorf("Request %d should be allowed within burst, reason: %s", i+1, reason) + } + } +} + +func TestRateLimiterBurstExhaustion(t *testing.T) { + logger := zap.NewNop() + rl := NewRateLimiter(logger) + + rl.SetLimit(MethodREGISTER, &MethodRateLimit{ + Method: MethodREGISTER, + MaxRequests: 10, + Window: time.Minute, + BurstSize: 3, + }) + + ip := "192.168.1.101" + + // Exhaust burst + for i := 0; i < 3; i++ { + allowed, _ := rl.Allow(ip, MethodREGISTER) + if !allowed { + t.Errorf("Request %d should be allowed within burst", i+1) + } + } + + // Next request should be rate limited (burst exhausted) + allowed, reason := rl.Allow(ip, MethodREGISTER) + if allowed { + t.Error("Request after burst should be rate limited") + } + if reason != "rate_limit_REGISTER" { + t.Errorf("Reason should be rate_limit_REGISTER, got: %s", reason) + } +} + +func TestRateLimiterTokenRefill(t *testing.T) { + logger := zap.NewNop() + rl := NewRateLimiter(logger) + + // 60 requests per minute = 1 per second + rl.SetLimit(MethodOPTIONS, &MethodRateLimit{ + Method: MethodOPTIONS, + MaxRequests: 60, + Window: time.Minute, + BurstSize: 2, + }) + + ip := "192.168.1.102" + + // Exhaust burst + rl.Allow(ip, MethodOPTIONS) + rl.Allow(ip, MethodOPTIONS) + + // Should be blocked now + allowed, _ := rl.Allow(ip, MethodOPTIONS) + if allowed { + t.Error("Should be rate limited after burst") + } + + // Wait for token refill (slightly more than 1 second for 1 token) + time.Sleep(1100 * time.Millisecond) + + // Should be allowed again + allowed, reason := rl.Allow(ip, MethodOPTIONS) + if !allowed { + t.Errorf("Should be allowed after token refill, reason: %s", reason) + } +} + +func TestRateLimiterDifferentMethods(t *testing.T) { + logger := zap.NewNop() + rl := NewRateLimiter(logger) + + // Different limits for different methods + rl.SetLimit(MethodREGISTER, &MethodRateLimit{ + Method: MethodREGISTER, + MaxRequests: 10, + Window: time.Minute, + BurstSize: 2, + }) + rl.SetLimit(MethodINVITE, &MethodRateLimit{ + Method: MethodINVITE, + MaxRequests: 30, + Window: time.Minute, + BurstSize: 5, + }) + + ip := "192.168.1.103" + + // Exhaust REGISTER burst + rl.Allow(ip, MethodREGISTER) + rl.Allow(ip, MethodREGISTER) + + // REGISTER should be blocked + allowed, _ := rl.Allow(ip, MethodREGISTER) + if allowed { + t.Error("REGISTER should be rate limited") + } + + // But INVITE should still work (separate bucket) + for i := 0; i < 5; i++ { + allowed, reason := rl.Allow(ip, MethodINVITE) + if !allowed { + t.Errorf("INVITE %d should be allowed, reason: %s", i+1, reason) + } + } +} + +func TestRateLimiterDifferentIPs(t *testing.T) { + logger := zap.NewNop() + rl := NewRateLimiter(logger) + + rl.SetLimit(MethodREGISTER, &MethodRateLimit{ + Method: MethodREGISTER, + MaxRequests: 10, + Window: time.Minute, + BurstSize: 2, + }) + + ip1 := "192.168.1.104" + ip2 := "192.168.1.105" + + // Exhaust IP1's burst + rl.Allow(ip1, MethodREGISTER) + rl.Allow(ip1, MethodREGISTER) + allowed, _ := rl.Allow(ip1, MethodREGISTER) + if allowed { + t.Error("IP1 should be rate limited") + } + + // IP2 should still work (separate bucket) + for i := 0; i < 2; i++ { + allowed, reason := rl.Allow(ip2, MethodREGISTER) + if !allowed { + t.Errorf("IP2 request %d should be allowed, reason: %s", i+1, reason) + } + } +} + +func TestRateLimiterDefaultLimits(t *testing.T) { + logger := zap.NewNop() + + // Reset global rate limiter to get a fresh instance with defaults applied + rateLimiterMu.Lock() + globalRateLimiter = nil + rateLimiterMu.Unlock() + + // GetRateLimiter applies default limits from DefaultMethodLimits + rl := GetRateLimiter(logger) + + // Test that DefaultMethodLimits are applied + for method, expectedLimit := range DefaultMethodLimits { + limit := rl.GetLimit(method) + if limit.MaxRequests != expectedLimit.MaxRequests { + t.Errorf("Default limit for %s: got %d, want %d", method, limit.MaxRequests, expectedLimit.MaxRequests) + } + } +} + +func TestNewRateLimiterHasNoDefaultLimits(t *testing.T) { + // NewRateLimiter creates a fresh limiter without default method limits + // This is intentional for testing and custom configurations + logger := zap.NewNop() + rl := NewRateLimiter(logger) + + // Should return the global default (100), not method-specific defaults + limit := rl.GetLimit(MethodREGISTER) + if limit.MaxRequests != 100 { + t.Errorf("NewRateLimiter should have global default 100 for unconfigured methods, got %d", limit.MaxRequests) + } +} + +func TestRateLimiterConcurrentAccess(t *testing.T) { + logger := zap.NewNop() + rl := NewRateLimiter(logger) + + rl.SetLimit(MethodREGISTER, &MethodRateLimit{ + Method: MethodREGISTER, + MaxRequests: 100, + Window: time.Minute, + BurstSize: 50, + }) + + var wg sync.WaitGroup + allowedCount := 0 + var mu sync.Mutex + + // Simulate 100 concurrent requests from different IPs + for i := 0; i < 100; i++ { + wg.Add(1) + go func(n int) { + defer wg.Done() + ip := "192.168.1." + string(rune('0'+n%10)) + allowed, _ := rl.Allow(ip, MethodREGISTER) + if allowed { + mu.Lock() + allowedCount++ + mu.Unlock() + } + }(i) + } + + wg.Wait() + + // Should have allowed most requests (10 IPs with burst of 50 each = up to 500 capacity) + if allowedCount < 90 { + t.Errorf("Expected at least 90 allowed requests, got %d", allowedCount) + } +} + +func TestRateLimiterCleanup(t *testing.T) { + logger := zap.NewNop() + rl := NewRateLimiter(logger) + + // Add some buckets + rl.Allow("192.168.1.1", MethodREGISTER) + rl.Allow("192.168.1.2", MethodREGISTER) + rl.Allow("192.168.1.3", MethodREGISTER) + + stats := rl.GetStats() + trackedIPs := stats["tracked_ips"].(int) + if trackedIPs != 3 { + t.Errorf("Should have 3 tracked IPs, got %d", trackedIPs) + } + + // Cleanup shouldn't remove fresh entries + rl.Cleanup() + stats = rl.GetStats() + trackedIPs = stats["tracked_ips"].(int) + if trackedIPs != 3 { + t.Errorf("Fresh entries should not be cleaned up, got %d IPs", trackedIPs) + } +} + +func TestRateLimiterStats(t *testing.T) { + logger := zap.NewNop() + rl := NewRateLimiter(logger) + + rl.SetLimit(MethodREGISTER, &MethodRateLimit{ + Method: MethodREGISTER, + MaxRequests: 10, + Window: time.Minute, + BurstSize: 5, + }) + + rl.Allow("192.168.1.100", MethodREGISTER) + + stats := rl.GetStats() + + if _, ok := stats["tracked_ips"]; !ok { + t.Error("Stats should include tracked_ips") + } + if _, ok := stats["limits"]; !ok { + t.Error("Stats should include limits") + } + + limits := stats["limits"].(map[string]interface{}) + if _, ok := limits["REGISTER"]; !ok { + t.Error("Limits should include REGISTER config") + } +} + +func TestGlobalRateLimiter(t *testing.T) { + logger := zap.NewNop() + + // Reset global rate limiter for clean test + rateLimiterMu.Lock() + globalRateLimiter = nil + rateLimiterMu.Unlock() + + rl := GetRateLimiter(logger) + if rl == nil { + t.Fatal("GetRateLimiter should return non-nil") + } + + // Should return same instance + rl2 := GetRateLimiter(logger) + if rl != rl2 { + t.Error("GetRateLimiter should return same global instance") + } + + // Should have default limits applied + registerLimit := rl.GetLimit(MethodREGISTER) + if registerLimit.MaxRequests != 10 { + t.Errorf("Global rate limiter should have default REGISTER limit, got %d", registerLimit.MaxRequests) + } +} + +// ============================================================================= +// SIP Parsing Function Tests +// ============================================================================= + +func TestExtractSIPMethod(t *testing.T) { + tests := []struct { + name string + data []byte + expected SIPMethod + }{ + {"REGISTER", []byte("REGISTER sip:example.com SIP/2.0\r\n"), MethodREGISTER}, + {"INVITE", []byte("INVITE sip:alice@example.com SIP/2.0\r\n"), MethodINVITE}, + {"OPTIONS", []byte("OPTIONS sip:example.com SIP/2.0\r\n"), MethodOPTIONS}, + {"ACK", []byte("ACK sip:alice@example.com SIP/2.0\r\n"), MethodACK}, + {"BYE", []byte("BYE sip:alice@example.com SIP/2.0\r\n"), MethodBYE}, + {"CANCEL", []byte("CANCEL sip:alice@example.com SIP/2.0\r\n"), MethodCANCEL}, + {"INFO", []byte("INFO sip:alice@example.com SIP/2.0\r\n"), MethodINFO}, + {"NOTIFY", []byte("NOTIFY sip:alice@example.com SIP/2.0\r\n"), MethodNOTIFY}, + {"SUBSCRIBE", []byte("SUBSCRIBE sip:alice@example.com SIP/2.0\r\n"), MethodSUBSCRIBE}, + {"MESSAGE", []byte("MESSAGE sip:alice@example.com SIP/2.0\r\n"), MethodMESSAGE}, + {"UPDATE", []byte("UPDATE sip:alice@example.com SIP/2.0\r\n"), MethodUPDATE}, + {"PRACK", []byte("PRACK sip:alice@example.com SIP/2.0\r\n"), MethodPRACK}, + {"REFER", []byte("REFER sip:alice@example.com SIP/2.0\r\n"), MethodREFER}, + {"PUBLISH", []byte("PUBLISH sip:alice@example.com SIP/2.0\r\n"), MethodPUBLISH}, + {"Response (no method)", []byte("SIP/2.0 200 OK\r\n"), ""}, + {"Non-SIP", []byte("GET / HTTP/1.1\r\n"), ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + method := ExtractSIPMethod(tt.data) + if method != tt.expected { + t.Errorf("ExtractSIPMethod: got %q, want %q", method, tt.expected) + } + }) + } +} + +// Note: TestExtractTargetExtension is already defined in enumeration_test.go +// These additional tests cover edge cases for rate limiter integration + +func TestExtractTargetExtensionEdgeCases(t *testing.T) { + tests := []struct { + name string + data []byte + expected string + }{ + // Legitimate extension extractions (additional cases) + { + name: "Simple 4-digit extension", + data: []byte("REGISTER sip:1001@example.com SIP/2.0\r\n"), + expected: "1001", + }, + { + name: "3-digit extension", + data: []byte("INVITE sip:200@pbx.local SIP/2.0\r\n"), + expected: "200", + }, + { + name: "Alphanumeric short name", + data: []byte("INVITE sip:john@example.com SIP/2.0\r\n"), + expected: "john", + }, + // Should NOT extract these as extensions (domain-like or too long) + { + name: "Full domain should not be extracted", + data: []byte("REGISTER sip:example.com SIP/2.0\r\n"), + expected: "", // Contains a dot, filtered out + }, + { + name: "Too long identifier", + data: []byte("INVITE sip:verylongusername@example.com SIP/2.0\r\n"), + expected: "", // >10 chars + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ext := ExtractTargetExtension(tt.data) + if ext != tt.expected { + t.Errorf("ExtractTargetExtension: got %q, want %q", ext, tt.expected) + } + }) + } +} + +func TestParseUserAgent(t *testing.T) { + tests := []struct { + name string + data []byte + expected string + }{ + { + name: "Yealink phone", + data: []byte("REGISTER sip:example.com SIP/2.0\r\n" + + "User-Agent: Yealink SIP-T46S 66.86.0.15\r\n"), + expected: "Yealink SIP-T46S 66.86.0.15", + }, + { + name: "Linphone", + data: []byte("INVITE sip:alice@example.com SIP/2.0\r\n" + + "User-Agent: Linphone/4.5.0\r\n"), + expected: "Linphone/4.5.0", + }, + { + name: "No User-Agent", + data: []byte("OPTIONS sip:example.com SIP/2.0\r\n" + + "Via: SIP/2.0/UDP 192.168.1.1\r\n"), + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ua := ParseUserAgent(tt.data) + if ua != tt.expected { + t.Errorf("ParseUserAgent: got %q, want %q", ua, tt.expected) + } + }) + } +} + +func TestParseFromHeader(t *testing.T) { + tests := []struct { + name string + data []byte + expectedUser string + expectedDomain string + }{ + { + name: "Standard From header", + data: []byte("INVITE sip:bob@example.com SIP/2.0\r\n" + + "From: \"Alice\" ;tag=1234\r\n"), + expectedUser: "alice", + expectedDomain: "sip.example.com", + }, + { + name: "From header without display name", + data: []byte("REGISTER sip:example.com SIP/2.0\r\n" + + "From: ;tag=abc\r\n"), + expectedUser: "1001", + expectedDomain: "pbx.local", + }, + { + name: "No From header", + data: []byte("OPTIONS sip:example.com SIP/2.0\r\n" + + "To: \r\n"), + expectedUser: "", + expectedDomain: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + user, domain := ParseFromHeader(tt.data) + if user != tt.expectedUser { + t.Errorf("ParseFromHeader user: got %q, want %q", user, tt.expectedUser) + } + if domain != tt.expectedDomain { + t.Errorf("ParseFromHeader domain: got %q, want %q", domain, tt.expectedDomain) + } + }) + } +} + +func TestParseToHeader(t *testing.T) { + tests := []struct { + name string + data []byte + expectedUser string + expectedDomain string + }{ + { + name: "Standard To header", + data: []byte("INVITE sip:bob@example.com SIP/2.0\r\n" + + "To: \"Bob\" \r\n"), + expectedUser: "bob", + expectedDomain: "sip.example.com", + }, + { + name: "To header without display name", + data: []byte("REGISTER sip:example.com SIP/2.0\r\n" + + "To: \r\n"), + expectedUser: "5002", + expectedDomain: "pbx.local", + }, + { + name: "No To header", + data: []byte("OPTIONS sip:example.com SIP/2.0\r\n" + + "From: \r\n"), + expectedUser: "", + expectedDomain: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + user, domain := ParseToHeader(tt.data) + if user != tt.expectedUser { + t.Errorf("ParseToHeader user: got %q, want %q", user, tt.expectedUser) + } + if domain != tt.expectedDomain { + t.Errorf("ParseToHeader domain: got %q, want %q", domain, tt.expectedDomain) + } + }) + } +} + +func TestParseCallID(t *testing.T) { + tests := []struct { + name string + data []byte + expected string + }{ + { + name: "Standard Call-ID header", + data: []byte("INVITE sip:bob@example.com SIP/2.0\r\n" + + "Call-ID: 12345-67890@192.168.1.100\r\n"), + expected: "12345-67890@192.168.1.100", + }, + { + name: "Short form Call-ID (i:)", + data: []byte("REGISTER sip:example.com SIP/2.0\r\n" + + "i: compact-callid@host\r\n"), + expected: "compact-callid@host", + }, + { + name: "No Call-ID", + data: []byte("OPTIONS sip:example.com SIP/2.0\r\n" + + "From: \r\n"), + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + callID := ParseCallID(tt.data) + if callID != tt.expected { + t.Errorf("ParseCallID: got %q, want %q", callID, tt.expected) + } + }) + } +} + +// ============================================================================= +// Real-world Legitimate Traffic Scenarios +// ============================================================================= + +func TestLegitimatePhoneRegistration(t *testing.T) { + // Simulate a typical phone registration pattern: + // 1. Initial REGISTER (usually 401/407 challenge) + // 2. Second REGISTER with auth + // 3. Keep-alive OPTIONS + // 4. Re-REGISTER before expiry + + logger := zap.NewNop() + + // Reset global rate limiter for clean test + rateLimiterMu.Lock() + globalRateLimiter = nil + rateLimiterMu.Unlock() + + rl := GetRateLimiter(logger) + ip := "192.168.1.200" + + // Simulate pattern over time + allowedCount := 0 + + // Initial REGISTER + if allowed, _ := rl.Allow(ip, MethodREGISTER); allowed { + allowedCount++ + } + // Auth REGISTER + if allowed, _ := rl.Allow(ip, MethodREGISTER); allowed { + allowedCount++ + } + // Keep-alive OPTIONS (phones do this frequently) + for i := 0; i < 5; i++ { + if allowed, _ := rl.Allow(ip, MethodOPTIONS); allowed { + allowedCount++ + } + } + // Another REGISTER + if allowed, _ := rl.Allow(ip, MethodREGISTER); allowed { + allowedCount++ + } + + // At minimum, the phone should complete registration (2 REGISTER + OPTIONS) + if allowedCount < 5 { + t.Errorf("Legitimate phone registration pattern blocked too early, only %d requests allowed", allowedCount) + } +} + +func TestLegitimateCallFlow(t *testing.T) { + // Simulate a typical call flow: + // INVITE -> 100 Trying -> 180 Ringing -> 200 OK -> ACK -> (call) -> BYE -> 200 OK + + logger := zap.NewNop() + + // Reset global rate limiter + rateLimiterMu.Lock() + globalRateLimiter = nil + rateLimiterMu.Unlock() + + rl := GetRateLimiter(logger) + ip := "192.168.1.201" + + // INVITE + allowed, reason := rl.Allow(ip, MethodINVITE) + if !allowed { + t.Errorf("INVITE should be allowed for call: %s", reason) + } + + // ACK (not rate limited by default) + allowed, reason = rl.Allow(ip, MethodACK) + if !allowed { + t.Errorf("ACK should be allowed for call: %s", reason) + } + + // BYE + allowed, reason = rl.Allow(ip, MethodBYE) + if !allowed { + t.Errorf("BYE should be allowed for call: %s", reason) + } +} + +func TestLegitimateSubscriptionFlow(t *testing.T) { + // Simulate presence/BLF subscription + logger := zap.NewNop() + + // Reset global rate limiter + rateLimiterMu.Lock() + globalRateLimiter = nil + rateLimiterMu.Unlock() + + rl := GetRateLimiter(logger) + ip := "192.168.1.202" + + // SUBSCRIBE for presence (phones do multiple for BLF) + allowedCount := 0 + for i := 0; i < 10; i++ { + if allowed, _ := rl.Allow(ip, MethodSUBSCRIBE); allowed { + allowedCount++ + } + } + + // Default allows 5 burst + some refill + if allowedCount < 5 { + t.Errorf("Legitimate SUBSCRIBE pattern blocked too early, only %d allowed", allowedCount) + } +} + +func TestLegitimateMessaging(t *testing.T) { + // Simulate SMS/messaging traffic + logger := zap.NewNop() + + // Reset global rate limiter + rateLimiterMu.Lock() + globalRateLimiter = nil + rateLimiterMu.Unlock() + + rl := GetRateLimiter(logger) + ip := "192.168.1.203" + + // MESSAGE (higher limit - 100/min default) + allowedCount := 0 + for i := 0; i < 25; i++ { + if allowed, _ := rl.Allow(ip, MethodMESSAGE); allowed { + allowedCount++ + } + } + + // Should allow most messages (20 burst default) + if allowedCount < 20 { + t.Errorf("Legitimate MESSAGE pattern blocked too early, only %d allowed", allowedCount) + } +} diff --git a/sipguardian.go b/sipguardian.go index eeeabb1..98d48cd 100644 --- a/sipguardian.go +++ b/sipguardian.go @@ -41,6 +41,11 @@ type SIPGuardian struct { BanTime caddy.Duration `json:"ban_time,omitempty"` WhitelistCIDR []string `json:"whitelist_cidr,omitempty"` + // DNS-aware whitelist configuration + WhitelistHosts []string `json:"whitelist_hosts,omitempty"` // Hostnames to resolve (A/AAAA) + WhitelistSRV []string `json:"whitelist_srv,omitempty"` // SRV records to resolve + DNSRefresh caddy.Duration `json:"dns_refresh,omitempty"` // DNS refresh interval (default: 5m) + // Webhook configuration Webhooks []WebhookConfig `json:"webhooks,omitempty"` @@ -48,7 +53,7 @@ type SIPGuardian struct { StoragePath string `json:"storage_path,omitempty"` // GeoIP configuration - GeoIPPath string `json:"geoip_path,omitempty"` + GeoIPPath string `json:"geoip_path,omitempty"` BlockedCountries []string `json:"blocked_countries,omitempty"` AllowedCountries []string `json:"allowed_countries,omitempty"` @@ -63,6 +68,7 @@ type SIPGuardian struct { bannedIPs map[string]*BanEntry failureCounts map[string]*failureTracker whitelistNets []*net.IPNet + dnsWhitelist *DNSWhitelist mu sync.RWMutex storage *Storage geoIP *GeoIPLookup @@ -155,6 +161,34 @@ func (g *SIPGuardian) Provision(ctx caddy.Context) error { } } + // Initialize DNS whitelist if configured + if len(g.WhitelistHosts) > 0 || len(g.WhitelistSRV) > 0 { + refreshInterval := 5 * time.Minute + if g.DNSRefresh > 0 { + refreshInterval = time.Duration(g.DNSRefresh) + } + + g.dnsWhitelist = NewDNSWhitelist(DNSWhitelistConfig{ + Hostnames: g.WhitelistHosts, + SRVRecords: g.WhitelistSRV, + RefreshInterval: refreshInterval, + AllowStale: true, + ResolveTimeout: 10 * time.Second, + }, g.logger) + + if err := g.dnsWhitelist.Start(); err != nil { + g.logger.Warn("Failed to initialize DNS whitelist", + zap.Error(err), + ) + } else { + g.logger.Info("DNS whitelist initialized", + zap.Int("hostnames", len(g.WhitelistHosts)), + zap.Int("srv_records", len(g.WhitelistSRV)), + zap.Duration("refresh_interval", refreshInterval), + ) + } + } + // Initialize enumeration detection with config if specified if g.Enumeration != nil { SetEnumerationConfig(*g.Enumeration) @@ -216,12 +250,14 @@ func (g *SIPGuardian) loadBansFromStorage() error { return nil } -// IsWhitelisted checks if an IP is in the whitelist +// IsWhitelisted checks if an IP is in the whitelist (CIDR or DNS-based) func (g *SIPGuardian) IsWhitelisted(ip string) bool { parsedIP := net.ParseIP(ip) if parsedIP == nil { return false } + + // Check CIDR-based whitelist for _, network := range g.whitelistNets { if network.Contains(parsedIP) { if enableMetrics { @@ -230,6 +266,19 @@ func (g *SIPGuardian) IsWhitelisted(ip string) bool { return true } } + + // Check DNS-based whitelist + if g.dnsWhitelist != nil && g.dnsWhitelist.Contains(ip) { + if enableMetrics { + RecordWhitelistedConnection() + } + g.logger.Debug("IP whitelisted via DNS", + zap.String("ip", ip), + zap.String("source", g.dnsWhitelist.GetSource(ip).Source), + ) + return true + } + return false } @@ -448,10 +497,34 @@ func (g *SIPGuardian) GetStats() map[string]interface{} { } } - return map[string]interface{}{ - "active_bans": activeBans, - "tracked_failures": len(g.failureCounts), - "whitelist_count": len(g.whitelistNets), + stats := map[string]interface{}{ + "active_bans": activeBans, + "tracked_failures": len(g.failureCounts), + "whitelist_cidr": len(g.whitelistNets), + "whitelist_hosts": len(g.WhitelistHosts), + "whitelist_srv": len(g.WhitelistSRV), + } + + // Add DNS whitelist stats if available + if g.dnsWhitelist != nil { + stats["dns_whitelist"] = g.dnsWhitelist.Stats() + } + + return stats +} + +// GetDNSWhitelistEntries returns all resolved DNS whitelist entries +func (g *SIPGuardian) GetDNSWhitelistEntries() []ResolvedEntry { + if g.dnsWhitelist == nil { + return nil + } + return g.dnsWhitelist.GetResolvedIPs() +} + +// RefreshDNSWhitelist forces an immediate refresh of DNS whitelist entries +func (g *SIPGuardian) RefreshDNSWhitelist() { + if g.dnsWhitelist != nil { + g.dnsWhitelist.ForceRefresh() } } @@ -500,8 +573,15 @@ func (g *SIPGuardian) cleanup() { // max_failures 5 // find_time 10m // ban_time 1h +// +// # IP/CIDR whitelist (static) // whitelist 10.0.0.0/8 192.168.0.0/16 // +// # DNS-aware whitelist (dynamic, auto-refreshed) +// whitelist_hosts pbx.example.com trunk.provider.com +// whitelist_srv _sip._udp.provider.com _sip._tcp.carrier.net +// dns_refresh 5m # How often to refresh DNS (default: 5m) +// // # Persistent storage // storage /var/lib/sip-guardian/guardian.db // @@ -552,10 +632,34 @@ func (g *SIPGuardian) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { g.BanTime = caddy.Duration(dur) case "whitelist": + // Legacy: CIDR-only whitelist for d.NextArg() { g.WhitelistCIDR = append(g.WhitelistCIDR, d.Val()) } + case "whitelist_hosts": + // DNS A/AAAA record whitelist (hostnames resolved to IPs) + for d.NextArg() { + g.WhitelistHosts = append(g.WhitelistHosts, d.Val()) + } + + case "whitelist_srv": + // DNS SRV record whitelist (e.g., _sip._udp.provider.com) + for d.NextArg() { + g.WhitelistSRV = append(g.WhitelistSRV, d.Val()) + } + + case "dns_refresh": + // Interval for refreshing DNS-based whitelist entries + if !d.NextArg() { + return d.ArgErr() + } + dur, err := caddy.ParseDuration(d.Val()) + if err != nil { + return d.Errf("invalid dns_refresh: %v", err) + } + g.DNSRefresh = caddy.Duration(dur) + case "storage": if !d.NextArg() { return d.ArgErr()