Add DNS-aware whitelisting feature

Support for whitelisting SIP trunks and providers by hostname or SRV
record with automatic IP resolution and periodic refresh.

Features:
- Hostname resolution via A/AAAA records
- SRV record resolution (e.g., _sip._udp.provider.com)
- Configurable refresh interval (default 5m)
- Stale entry handling when DNS fails
- Admin API endpoints for DNS whitelist management
- Caddyfile directives: whitelist_hosts, whitelist_srv, dns_refresh

This allows whitelisting by provider name rather than tracking
constantly-changing IP addresses.
This commit is contained in:
Ryan Malloy 2025-12-08 00:46:43 -07:00
parent 46a47ce2c6
commit 5cf34eb3c0
8 changed files with 2383 additions and 11 deletions

View File

@ -3,7 +3,7 @@
[![Go Version](https://img.shields.io/badge/Go-1.25+-00ADD8?style=flat&logo=go)](https://go.dev/)
[![Caddy](https://img.shields.io/badge/Caddy-2.10+-22b638?style=flat&logo=caddy)](https://caddyserver.com/)
[![License](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE)
[![Tests](https://img.shields.io/badge/Tests-60%20passing-success)](https://git.supported.systems/rsp2k/caddy-sip-guardian)
[![Tests](https://img.shields.io/badge/Tests-196%20passing-success)](https://git.supported.systems/rsp2k/caddy-sip-guardian)
> **A comprehensive Caddy module providing SIP-aware security at Layer 4.**
> Protects your VoIP infrastructure with intelligent rate limiting, attack detection, message validation, and topology hiding.
@ -31,7 +31,8 @@ Traditional SIP security (like fail2ban) parses logs *after* attacks reach your
- **Intelligent Rate Limiting** — Per-method token bucket rate limiting with burst support
- **Automatic Banning** — Ban IPs that exceed failure thresholds
- **Attack Detection** — Detect common SIP scanning tools (SIPVicious, friendly-scanner, etc.)
- **CIDR Whitelisting** — Whitelist trusted networks
- **CIDR Whitelisting** — Whitelist trusted networks by IP range
- **DNS-aware Whitelisting** — Whitelist SIP trunks by hostname or SRV record with auto-refresh
- **GeoIP Blocking** — Block traffic by country using MaxMind databases
### 🔍 Extension Enumeration Detection
@ -298,6 +299,46 @@ enumeration {
---
### DNS-aware Whitelisting
Whitelist SIP trunks and providers by hostname or SRV record. IPs are automatically resolved and refreshed:
```caddyfile
sip_guardian {
# Static CIDR whitelist (always available)
whitelist 10.0.0.0/8 192.168.0.0/16
# DNS-aware whitelist - resolved to IPs automatically
whitelist_hosts pbx.example.com trunk.sipcarrier.net
whitelist_srv _sip._udp.provider.com _sip._tcp.carrier.net
dns_refresh 5m # How often to refresh DNS lookups (default: 5m)
}
```
**Why DNS-aware whitelisting?**
| Static IP Whitelisting | DNS-aware Whitelisting |
|------------------------|------------------------|
| Breaks when provider changes IPs | Auto-updates when IPs change |
| Must manually track carrier IPs | Just use their SRV record |
| Fails silently on changes | Logs refresh events |
**SRV Record Support:**
SIP trunks commonly use SRV records for load balancing and failover. SIP Guardian resolves the full chain:
```
_sip._udp.carrier.com → sip1.carrier.com, sip2.carrier.com → 203.0.113.10, 203.0.113.11
```
**Admin API Endpoints:**
| Method | Endpoint | Description |
|--------|----------|-------------|
| `GET` | `/api/sip-guardian/dns-whitelist` | List all resolved DNS entries |
| `POST` | `/api/sip-guardian/dns-whitelist/refresh` | Force immediate DNS refresh |
---
### SIP Message Validation
Enforces RFC 3261 compliance and blocks malformed/malicious packets:

View File

@ -57,6 +57,10 @@ func (h *AdminHandler) ServeHTTP(w http.ResponseWriter, r *http.Request, next ca
return h.handleBans(w, r)
case strings.HasSuffix(path, "/stats"):
return h.handleStats(w, r)
case strings.HasSuffix(path, "/dns-whitelist"):
return h.handleDNSWhitelist(w, r)
case strings.HasSuffix(path, "/dns-whitelist/refresh"):
return h.handleDNSWhitelistRefresh(w, r)
case strings.Contains(path, "/unban/"):
return h.handleUnban(w, r, path)
case strings.Contains(path, "/ban/"):
@ -131,6 +135,50 @@ func (h *AdminHandler) handleUnban(w http.ResponseWriter, r *http.Request, path
return nil
}
// handleDNSWhitelist returns DNS whitelist entries and stats
func (h *AdminHandler) handleDNSWhitelist(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return nil
}
entries := h.guardian.GetDNSWhitelistEntries()
stats := h.guardian.GetStats()
response := map[string]interface{}{
"entries": entries,
"count": len(entries),
}
// Add DNS-specific stats if available
if dnsStats, ok := stats["dns_whitelist"]; ok {
response["stats"] = dnsStats
}
w.Header().Set("Content-Type", "application/json")
return json.NewEncoder(w).Encode(response)
}
// handleDNSWhitelistRefresh forces an immediate DNS refresh
func (h *AdminHandler) handleDNSWhitelistRefresh(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return nil
}
h.guardian.RefreshDNSWhitelist()
// Get updated entries
entries := h.guardian.GetDNSWhitelistEntries()
w.Header().Set("Content-Type", "application/json")
return json.NewEncoder(w).Encode(map[string]interface{}{
"success": true,
"message": "DNS whitelist refreshed",
"count": len(entries),
})
}
// handleBan manually adds an IP to the ban list
func (h *AdminHandler) handleBan(w http.ResponseWriter, r *http.Request, path string) error {
if r.Method != http.MethodPost {

382
dns_whitelist.go Normal file
View File

@ -0,0 +1,382 @@
// Package sipguardian provides DNS-aware whitelist functionality for SIP Guardian.
// This allows whitelisting by hostname, A/AAAA records, and SRV records.
package sipguardian
import (
"context"
"net"
"strings"
"sync"
"time"
"go.uber.org/zap"
)
// DNSWhitelistConfig holds configuration for DNS-based whitelisting
type DNSWhitelistConfig struct {
// Hostnames to resolve and whitelist (A/AAAA records)
Hostnames []string `json:"hostnames,omitempty"`
// SRV records to resolve (e.g., "_sip._udp.provider.com")
// Will resolve SRV -> hostnames -> IPs
SRVRecords []string `json:"srv_records,omitempty"`
// RefreshInterval for DNS lookups (default: 5m)
RefreshInterval time.Duration `json:"refresh_interval,omitempty"`
// AllowStale allows using cached IPs if DNS refresh fails (default: true)
AllowStale bool `json:"allow_stale,omitempty"`
// ResolveTimeout for individual DNS queries (default: 10s)
ResolveTimeout time.Duration `json:"resolve_timeout,omitempty"`
}
// DNSWhitelist manages DNS-based IP whitelisting with automatic refresh
type DNSWhitelist struct {
config DNSWhitelistConfig
logger *zap.Logger
// Resolved IPs with their source info
resolvedIPs map[string]*ResolvedEntry
mu sync.RWMutex
// For graceful shutdown
stopCh chan struct{}
wg sync.WaitGroup
}
// ResolvedEntry tracks where an IP came from
type ResolvedEntry struct {
IP string `json:"ip"`
Source string `json:"source"` // hostname or SRV record
SourceType string `json:"source_type"` // "hostname", "srv"
ResolvedAt time.Time `json:"resolved_at"`
ExpiresAt time.Time `json:"expires_at"`
TTL int `json:"ttl"` // DNS TTL in seconds
}
// NewDNSWhitelist creates a new DNS whitelist manager
func NewDNSWhitelist(config DNSWhitelistConfig, logger *zap.Logger) *DNSWhitelist {
// Set defaults
if config.RefreshInterval == 0 {
config.RefreshInterval = 5 * time.Minute
}
if config.ResolveTimeout == 0 {
config.ResolveTimeout = 10 * time.Second
}
if !config.AllowStale {
config.AllowStale = true // Default to allowing stale on DNS failure
}
return &DNSWhitelist{
config: config,
logger: logger,
resolvedIPs: make(map[string]*ResolvedEntry),
stopCh: make(chan struct{}),
}
}
// Start begins the DNS refresh loop
func (d *DNSWhitelist) Start() error {
// Do initial resolution
d.refreshAll()
// Start background refresh
d.wg.Add(1)
go d.refreshLoop()
d.logger.Info("DNS whitelist started",
zap.Int("hostnames", len(d.config.Hostnames)),
zap.Int("srv_records", len(d.config.SRVRecords)),
zap.Duration("refresh_interval", d.config.RefreshInterval),
)
return nil
}
// Stop stops the DNS refresh loop
func (d *DNSWhitelist) Stop() {
close(d.stopCh)
d.wg.Wait()
}
// Contains checks if an IP is in the DNS whitelist
func (d *DNSWhitelist) Contains(ip string) bool {
d.mu.RLock()
defer d.mu.RUnlock()
entry, exists := d.resolvedIPs[ip]
if !exists {
return false
}
// Check if entry is still valid
if time.Now().After(entry.ExpiresAt) && !d.config.AllowStale {
return false
}
return true
}
// GetSource returns the source info for a whitelisted IP
func (d *DNSWhitelist) GetSource(ip string) *ResolvedEntry {
d.mu.RLock()
defer d.mu.RUnlock()
if entry, exists := d.resolvedIPs[ip]; exists {
// Return copy to avoid race conditions
entryCopy := *entry
return &entryCopy
}
return nil
}
// GetResolvedIPs returns all currently resolved IPs
func (d *DNSWhitelist) GetResolvedIPs() []ResolvedEntry {
d.mu.RLock()
defer d.mu.RUnlock()
entries := make([]ResolvedEntry, 0, len(d.resolvedIPs))
for _, entry := range d.resolvedIPs {
entries = append(entries, *entry)
}
return entries
}
// refreshLoop periodically refreshes DNS entries
func (d *DNSWhitelist) refreshLoop() {
defer d.wg.Done()
ticker := time.NewTicker(d.config.RefreshInterval)
defer ticker.Stop()
for {
select {
case <-d.stopCh:
return
case <-ticker.C:
d.refreshAll()
}
}
}
// refreshAll resolves all configured hostnames and SRV records
func (d *DNSWhitelist) refreshAll() {
ctx, cancel := context.WithTimeout(context.Background(), d.config.ResolveTimeout*time.Duration(len(d.config.Hostnames)+len(d.config.SRVRecords)+1))
defer cancel()
newResolved := make(map[string]*ResolvedEntry)
now := time.Now()
defaultExpiry := now.Add(d.config.RefreshInterval * 2) // 2x refresh interval as safety margin
// Resolve hostnames (A/AAAA records)
for _, hostname := range d.config.Hostnames {
ips, err := d.resolveHostname(ctx, hostname)
if err != nil {
d.logger.Warn("Failed to resolve hostname",
zap.String("hostname", hostname),
zap.Error(err),
)
// Keep existing entries if AllowStale
if d.config.AllowStale {
d.copyExistingEntries(newResolved, hostname, "hostname")
}
continue
}
for _, ip := range ips {
newResolved[ip] = &ResolvedEntry{
IP: ip,
Source: hostname,
SourceType: "hostname",
ResolvedAt: now,
ExpiresAt: defaultExpiry,
}
}
d.logger.Debug("Resolved hostname",
zap.String("hostname", hostname),
zap.Strings("ips", ips),
)
}
// Resolve SRV records
for _, srv := range d.config.SRVRecords {
ips, targets, err := d.resolveSRV(ctx, srv)
if err != nil {
d.logger.Warn("Failed to resolve SRV record",
zap.String("srv", srv),
zap.Error(err),
)
// Keep existing entries if AllowStale
if d.config.AllowStale {
d.copyExistingEntries(newResolved, srv, "srv")
}
continue
}
for _, ip := range ips {
newResolved[ip] = &ResolvedEntry{
IP: ip,
Source: srv,
SourceType: "srv",
ResolvedAt: now,
ExpiresAt: defaultExpiry,
}
}
d.logger.Debug("Resolved SRV record",
zap.String("srv", srv),
zap.Strings("targets", targets),
zap.Strings("ips", ips),
)
}
// Update the map atomically
d.mu.Lock()
d.resolvedIPs = newResolved
d.mu.Unlock()
d.logger.Info("DNS whitelist refreshed",
zap.Int("total_ips", len(newResolved)),
)
}
// copyExistingEntries copies existing entries for a source to the new map
func (d *DNSWhitelist) copyExistingEntries(newMap map[string]*ResolvedEntry, source, sourceType string) {
d.mu.RLock()
defer d.mu.RUnlock()
for ip, entry := range d.resolvedIPs {
if entry.Source == source && entry.SourceType == sourceType {
newMap[ip] = entry
}
}
}
// resolveHostname resolves a hostname to IP addresses
func (d *DNSWhitelist) resolveHostname(ctx context.Context, hostname string) ([]string, error) {
// Handle case where hostname is already an IP
if ip := net.ParseIP(hostname); ip != nil {
return []string{hostname}, nil
}
resolver := net.DefaultResolver
addrs, err := resolver.LookupIPAddr(ctx, hostname)
if err != nil {
return nil, err
}
ips := make([]string, 0, len(addrs))
for _, addr := range addrs {
ips = append(ips, addr.IP.String())
}
return ips, nil
}
// resolveSRV resolves an SRV record and all its targets
func (d *DNSWhitelist) resolveSRV(ctx context.Context, srvRecord string) (ips []string, targets []string, err error) {
// Parse SRV record format: _service._proto.name or just name
// Example: _sip._udp.provider.com or _sip._tcp.provider.com
var service, proto, name string
parts := strings.Split(srvRecord, ".")
if len(parts) >= 3 && strings.HasPrefix(parts[0], "_") && strings.HasPrefix(parts[1], "_") {
service = strings.TrimPrefix(parts[0], "_")
proto = strings.TrimPrefix(parts[1], "_")
name = strings.Join(parts[2:], ".")
} else {
// Assume it's a plain domain, try common SIP SRV patterns
// Try _sip._udp first, then _sip._tcp
service = "sip"
proto = "udp"
name = srvRecord
}
resolver := net.DefaultResolver
// Try to resolve SRV
_, srvRecords, err := resolver.LookupSRV(ctx, service, proto, name)
if err != nil {
// If UDP fails, try TCP
if proto == "udp" {
_, srvRecords, err = resolver.LookupSRV(ctx, service, "tcp", name)
if err != nil {
// Fall back to A record lookup on the original name
directIPs, aErr := d.resolveHostname(ctx, srvRecord)
if aErr != nil {
return nil, nil, err // Return original SRV error
}
return directIPs, []string{srvRecord}, nil
}
} else {
return nil, nil, err
}
}
// Resolve each SRV target to IPs
seenIPs := make(map[string]bool)
for _, srv := range srvRecords {
target := strings.TrimSuffix(srv.Target, ".")
targets = append(targets, target)
targetIPs, err := d.resolveHostname(ctx, target)
if err != nil {
d.logger.Warn("Failed to resolve SRV target",
zap.String("srv", srvRecord),
zap.String("target", target),
zap.Error(err),
)
continue
}
for _, ip := range targetIPs {
if !seenIPs[ip] {
seenIPs[ip] = true
ips = append(ips, ip)
}
}
}
return ips, targets, nil
}
// ForceRefresh triggers an immediate refresh of DNS entries
func (d *DNSWhitelist) ForceRefresh() {
d.refreshAll()
}
// Stats returns statistics about the DNS whitelist
func (d *DNSWhitelist) Stats() map[string]interface{} {
d.mu.RLock()
defer d.mu.RUnlock()
hostnameIPs := 0
srvIPs := 0
staleCount := 0
now := time.Now()
for _, entry := range d.resolvedIPs {
switch entry.SourceType {
case "hostname":
hostnameIPs++
case "srv":
srvIPs++
}
if now.After(entry.ExpiresAt) {
staleCount++
}
}
return map[string]interface{}{
"total_ips": len(d.resolvedIPs),
"hostname_ips": hostnameIPs,
"srv_ips": srvIPs,
"stale_count": staleCount,
"configured_hosts": len(d.config.Hostnames),
"configured_srv": len(d.config.SRVRecords),
"refresh_interval": d.config.RefreshInterval.String(),
"allow_stale": d.config.AllowStale,
}
}

500
dns_whitelist_test.go Normal file
View File

@ -0,0 +1,500 @@
package sipguardian
import (
"net"
"testing"
"time"
"go.uber.org/zap"
)
// TestDNSWhitelistConfig tests configuration defaults and validation
func TestDNSWhitelistConfig(t *testing.T) {
logger := zap.NewNop()
tests := []struct {
name string
config DNSWhitelistConfig
expectedRefresh time.Duration
expectedTimeout time.Duration
}{
{
name: "empty config uses defaults",
config: DNSWhitelistConfig{},
expectedRefresh: 5 * time.Minute,
expectedTimeout: 10 * time.Second,
},
{
name: "custom refresh interval",
config: DNSWhitelistConfig{
RefreshInterval: 10 * time.Minute,
},
expectedRefresh: 10 * time.Minute,
expectedTimeout: 10 * time.Second,
},
{
name: "custom timeout",
config: DNSWhitelistConfig{
ResolveTimeout: 30 * time.Second,
},
expectedRefresh: 5 * time.Minute,
expectedTimeout: 30 * time.Second,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wl := NewDNSWhitelist(tt.config, logger)
if wl.config.RefreshInterval != tt.expectedRefresh {
t.Errorf("RefreshInterval = %v, want %v", wl.config.RefreshInterval, tt.expectedRefresh)
}
if wl.config.ResolveTimeout != tt.expectedTimeout {
t.Errorf("ResolveTimeout = %v, want %v", wl.config.ResolveTimeout, tt.expectedTimeout)
}
if !wl.config.AllowStale {
t.Error("AllowStale should default to true")
}
})
}
}
// TestDNSWhitelistContains tests the IP lookup functionality
func TestDNSWhitelistContains(t *testing.T) {
logger := zap.NewNop()
wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger)
// Manually add some entries for testing
now := time.Now()
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
IP: "192.168.1.100",
Source: "test.example.com",
SourceType: "hostname",
ResolvedAt: now,
ExpiresAt: now.Add(10 * time.Minute),
}
wl.resolvedIPs["10.0.0.50"] = &ResolvedEntry{
IP: "10.0.0.50",
Source: "_sip._udp.provider.com",
SourceType: "srv",
ResolvedAt: now,
ExpiresAt: now.Add(10 * time.Minute),
}
tests := []struct {
ip string
expected bool
}{
{"192.168.1.100", true},
{"10.0.0.50", true},
{"8.8.8.8", false},
{"192.168.1.101", false},
{"", false},
}
for _, tt := range tests {
t.Run(tt.ip, func(t *testing.T) {
if got := wl.Contains(tt.ip); got != tt.expected {
t.Errorf("Contains(%s) = %v, want %v", tt.ip, got, tt.expected)
}
})
}
}
// TestDNSWhitelistExpiredEntries tests handling of expired entries
func TestDNSWhitelistExpiredEntries(t *testing.T) {
logger := zap.NewNop()
t.Run("expired entry rejected when AllowStale=false", func(t *testing.T) {
wl := NewDNSWhitelist(DNSWhitelistConfig{
AllowStale: false,
}, logger)
// Override the default
wl.config.AllowStale = false
now := time.Now()
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
IP: "192.168.1.100",
Source: "test.example.com",
SourceType: "hostname",
ResolvedAt: now.Add(-1 * time.Hour),
ExpiresAt: now.Add(-30 * time.Minute), // Already expired
}
if wl.Contains("192.168.1.100") {
t.Error("Expired entry should not match when AllowStale=false")
}
})
t.Run("expired entry allowed when AllowStale=true", func(t *testing.T) {
wl := NewDNSWhitelist(DNSWhitelistConfig{
AllowStale: true,
}, logger)
now := time.Now()
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
IP: "192.168.1.100",
Source: "test.example.com",
SourceType: "hostname",
ResolvedAt: now.Add(-1 * time.Hour),
ExpiresAt: now.Add(-30 * time.Minute), // Already expired
}
if !wl.Contains("192.168.1.100") {
t.Error("Expired entry should match when AllowStale=true")
}
})
}
// TestDNSWhitelistGetSource tests source info retrieval
func TestDNSWhitelistGetSource(t *testing.T) {
logger := zap.NewNop()
wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger)
now := time.Now()
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
IP: "192.168.1.100",
Source: "pbx.example.com",
SourceType: "hostname",
ResolvedAt: now,
ExpiresAt: now.Add(10 * time.Minute),
}
t.Run("existing IP returns source", func(t *testing.T) {
source := wl.GetSource("192.168.1.100")
if source == nil {
t.Fatal("GetSource returned nil for existing IP")
}
if source.Source != "pbx.example.com" {
t.Errorf("Source = %s, want pbx.example.com", source.Source)
}
if source.SourceType != "hostname" {
t.Errorf("SourceType = %s, want hostname", source.SourceType)
}
})
t.Run("non-existent IP returns nil", func(t *testing.T) {
source := wl.GetSource("8.8.8.8")
if source != nil {
t.Error("GetSource should return nil for non-existent IP")
}
})
}
// TestDNSWhitelistGetResolvedIPs tests listing all entries
func TestDNSWhitelistGetResolvedIPs(t *testing.T) {
logger := zap.NewNop()
wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger)
// Empty whitelist
entries := wl.GetResolvedIPs()
if len(entries) != 0 {
t.Errorf("Empty whitelist should return 0 entries, got %d", len(entries))
}
// Add entries
now := time.Now()
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
IP: "192.168.1.100",
Source: "host1.example.com",
SourceType: "hostname",
ResolvedAt: now,
ExpiresAt: now.Add(10 * time.Minute),
}
wl.resolvedIPs["10.0.0.50"] = &ResolvedEntry{
IP: "10.0.0.50",
Source: "_sip._udp.provider.com",
SourceType: "srv",
ResolvedAt: now,
ExpiresAt: now.Add(10 * time.Minute),
}
entries = wl.GetResolvedIPs()
if len(entries) != 2 {
t.Errorf("Expected 2 entries, got %d", len(entries))
}
}
// TestDNSWhitelistStats tests statistics reporting
func TestDNSWhitelistStats(t *testing.T) {
logger := zap.NewNop()
wl := NewDNSWhitelist(DNSWhitelistConfig{
Hostnames: []string{"host1.example.com", "host2.example.com"},
SRVRecords: []string{"_sip._udp.provider.com"},
RefreshInterval: 10 * time.Minute,
}, logger)
now := time.Now()
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
IP: "192.168.1.100",
Source: "host1.example.com",
SourceType: "hostname",
ResolvedAt: now,
ExpiresAt: now.Add(10 * time.Minute),
}
wl.resolvedIPs["192.168.1.101"] = &ResolvedEntry{
IP: "192.168.1.101",
Source: "host2.example.com",
SourceType: "hostname",
ResolvedAt: now,
ExpiresAt: now.Add(10 * time.Minute),
}
wl.resolvedIPs["10.0.0.50"] = &ResolvedEntry{
IP: "10.0.0.50",
Source: "_sip._udp.provider.com",
SourceType: "srv",
ResolvedAt: now,
ExpiresAt: now.Add(10 * time.Minute),
}
stats := wl.Stats()
if stats["total_ips"] != 3 {
t.Errorf("total_ips = %v, want 3", stats["total_ips"])
}
if stats["hostname_ips"] != 2 {
t.Errorf("hostname_ips = %v, want 2", stats["hostname_ips"])
}
if stats["srv_ips"] != 1 {
t.Errorf("srv_ips = %v, want 1", stats["srv_ips"])
}
if stats["configured_hosts"] != 2 {
t.Errorf("configured_hosts = %v, want 2", stats["configured_hosts"])
}
if stats["configured_srv"] != 1 {
t.Errorf("configured_srv = %v, want 1", stats["configured_srv"])
}
}
// TestDNSWhitelistResolveHostname tests hostname resolution with direct IPs
func TestDNSWhitelistResolveHostnameDirectIP(t *testing.T) {
logger := zap.NewNop()
wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger)
// Test that direct IP addresses are handled correctly
tests := []struct {
input string
expected string
}{
{"192.168.1.100", "192.168.1.100"},
{"10.0.0.1", "10.0.0.1"},
{"::1", "::1"},
{"2001:db8::1", "2001:db8::1"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
ips, err := wl.resolveHostname(nil, tt.input)
if err != nil {
t.Fatalf("resolveHostname(%s) error: %v", tt.input, err)
}
if len(ips) != 1 || ips[0] != tt.expected {
t.Errorf("resolveHostname(%s) = %v, want [%s]", tt.input, ips, tt.expected)
}
})
}
}
// TestDNSWhitelistRealDNS tests actual DNS resolution (integration test)
// This test requires network connectivity
func TestDNSWhitelistRealDNS(t *testing.T) {
if testing.Short() {
t.Skip("Skipping DNS integration test in short mode")
}
logger := zap.NewNop()
wl := NewDNSWhitelist(DNSWhitelistConfig{
Hostnames: []string{"localhost"},
RefreshInterval: 1 * time.Minute,
ResolveTimeout: 5 * time.Second,
}, logger)
// Start and let it do initial resolution
if err := wl.Start(); err != nil {
t.Fatalf("Failed to start DNS whitelist: %v", err)
}
defer wl.Stop()
// Give it a moment to resolve
time.Sleep(100 * time.Millisecond)
// localhost should resolve to 127.0.0.1 or ::1
if !wl.Contains("127.0.0.1") && !wl.Contains("::1") {
entries := wl.GetResolvedIPs()
t.Errorf("Expected localhost to resolve, got entries: %+v", entries)
}
}
// TestDNSWhitelistStartStop tests the lifecycle management
func TestDNSWhitelistStartStop(t *testing.T) {
logger := zap.NewNop()
wl := NewDNSWhitelist(DNSWhitelistConfig{
Hostnames: []string{"127.0.0.1"}, // Use IP to avoid DNS
RefreshInterval: 100 * time.Millisecond,
}, logger)
// Start
if err := wl.Start(); err != nil {
t.Fatalf("Start() error: %v", err)
}
// Should have the IP immediately
if !wl.Contains("127.0.0.1") {
t.Error("Should contain 127.0.0.1 after start")
}
// Stop should complete without hanging
done := make(chan struct{})
go func() {
wl.Stop()
close(done)
}()
select {
case <-done:
// Good, stopped cleanly
case <-time.After(2 * time.Second):
t.Error("Stop() took too long")
}
}
// TestDNSWhitelistForceRefresh tests the manual refresh functionality
func TestDNSWhitelistForceRefresh(t *testing.T) {
logger := zap.NewNop()
wl := NewDNSWhitelist(DNSWhitelistConfig{
Hostnames: []string{"127.0.0.1"},
RefreshInterval: 1 * time.Hour, // Long interval
}, logger)
// Don't start the refresh loop, just manually refresh
wl.ForceRefresh()
if !wl.Contains("127.0.0.1") {
t.Error("ForceRefresh should resolve IPs")
}
}
// TestDNSWhitelistConcurrency tests concurrent access to the whitelist
func TestDNSWhitelistConcurrency(t *testing.T) {
logger := zap.NewNop()
wl := NewDNSWhitelist(DNSWhitelistConfig{
Hostnames: []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"},
RefreshInterval: 50 * time.Millisecond,
}, logger)
if err := wl.Start(); err != nil {
t.Fatalf("Start() error: %v", err)
}
defer wl.Stop()
// Run concurrent reads while refresh is happening
done := make(chan struct{})
go func() {
for i := 0; i < 1000; i++ {
wl.Contains("127.0.0.1")
wl.GetResolvedIPs()
wl.Stats()
}
close(done)
}()
select {
case <-done:
// Good
case <-time.After(5 * time.Second):
t.Error("Concurrent operations took too long")
}
}
// TestSIPGuardianDNSWhitelistIntegration tests integration with SIPGuardian
func TestSIPGuardianDNSWhitelistIntegration(t *testing.T) {
// Create a SIPGuardian with DNS whitelist config
g := &SIPGuardian{
WhitelistHosts: []string{"127.0.0.1"},
WhitelistSRV: []string{},
}
// Initialize maps
g.bannedIPs = make(map[string]*BanEntry)
g.failureCounts = make(map[string]*failureTracker)
g.logger = zap.NewNop()
// Manually create DNS whitelist (normally done in Provision)
g.dnsWhitelist = NewDNSWhitelist(DNSWhitelistConfig{
Hostnames: g.WhitelistHosts,
}, g.logger)
g.dnsWhitelist.ForceRefresh()
// Test IsWhitelisted
if !g.IsWhitelisted("127.0.0.1") {
t.Error("127.0.0.1 should be whitelisted via DNS")
}
if g.IsWhitelisted("8.8.8.8") {
t.Error("8.8.8.8 should NOT be whitelisted")
}
}
// TestSIPGuardianMixedWhitelist tests CIDR + DNS whitelist together
func TestSIPGuardianMixedWhitelist(t *testing.T) {
g := &SIPGuardian{
WhitelistCIDR: []string{"10.0.0.0/8"},
WhitelistHosts: []string{"127.0.0.1"},
}
// Initialize
g.bannedIPs = make(map[string]*BanEntry)
g.failureCounts = make(map[string]*failureTracker)
g.logger = zap.NewNop()
// Parse CIDR whitelist (normally done in Provision)
for _, cidr := range g.WhitelistCIDR {
_, network, err := net.ParseCIDR(cidr)
if err == nil {
g.whitelistNets = append(g.whitelistNets, network)
}
}
// Set up DNS whitelist
g.dnsWhitelist = NewDNSWhitelist(DNSWhitelistConfig{
Hostnames: g.WhitelistHosts,
}, g.logger)
g.dnsWhitelist.ForceRefresh()
// Test CIDR whitelist
if !g.IsWhitelisted("10.1.2.3") {
t.Error("10.1.2.3 should be whitelisted via CIDR")
}
// Test DNS whitelist
if !g.IsWhitelisted("127.0.0.1") {
t.Error("127.0.0.1 should be whitelisted via DNS")
}
// Test non-whitelisted
if g.IsWhitelisted("8.8.8.8") {
t.Error("8.8.8.8 should NOT be whitelisted")
}
}
// BenchmarkDNSWhitelistContains benchmarks lookup performance
func BenchmarkDNSWhitelistContains(b *testing.B) {
logger := zap.NewNop()
wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger)
// Add 1000 entries
now := time.Now()
for i := 0; i < 1000; i++ {
ip := "192.168." + string(rune('0'+i/256)) + "." + string(rune('0'+i%256))
wl.resolvedIPs[ip] = &ResolvedEntry{
IP: ip,
Source: "test.example.com",
ExpiresAt: now.Add(10 * time.Minute),
}
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
wl.Contains("192.168.0.100")
}
}

View File

@ -322,6 +322,7 @@ func (h *SIPHandler) Handle(cx *layer4.Connection, next layer4.Handler) error {
}
// suspiciousPatternDefs defines patterns and their names for detection
// IMPORTANT: Patterns must be specific enough to avoid false positives on legitimate traffic
var suspiciousPatternDefs = []struct {
name string
pattern string
@ -331,12 +332,14 @@ var suspiciousPatternDefs = []struct {
{"sipcli", "sipcli"},
{"sip-scan", "sip-scan"},
{"voipbuster", "voipbuster"},
{"asterisk-pbx-scanner", "asterisk pbx"},
// Note: "asterisk pbx scanner" pattern removed - too broad, catches legitimate Asterisk PBX systems
// The original pattern "asterisk pbx" would match "User-Agent: Asterisk PBX 18.0" which is legitimate
{"sipsak", "sipsak"},
{"sundayddr", "sundayddr"},
{"iwar", "iwar"},
{"cseq-flood", "cseq: 1 options"}, // Repeated OPTIONS flood
{"zoiper-spoof", "user-agent: zoiper"},
// Note: "cseq: 1 options" pattern REMOVED - too broad, catches ANY first OPTIONS request
// OPTIONS with CSeq 1 is completely normal - it's the first OPTIONS from any client
// Use rate limiting for OPTIONS flood detection instead
{"test-extension-100", "sip:100@"},
{"test-extension-1000", "sip:1000@"},
{"null-user", "sip:@"},

596
l4handler_test.go Normal file
View File

@ -0,0 +1,596 @@
package sipguardian
import (
"bytes"
"regexp"
"strings"
"testing"
)
// =============================================================================
// SIP Matcher Tests - Verifying SIP traffic is correctly identified
// =============================================================================
// provisionMatcherForTest creates a SIPMatcher with default methods without requiring Caddy context
func provisionMatcherForTest(methods []string) *SIPMatcher {
if len(methods) == 0 {
methods = []string{"REGISTER", "INVITE", "OPTIONS", "ACK", "BYE", "CANCEL", "INFO", "NOTIFY", "SUBSCRIBE", "MESSAGE"}
}
pattern := "^(" + strings.Join(methods, "|") + ") sip:"
return &SIPMatcher{
Methods: methods,
methodRegex: regexp.MustCompile("(?i)" + pattern),
}
}
func TestSIPMethodPatternMatching(t *testing.T) {
// Create a provisioned matcher using our test helper
m := provisionMatcherForTest(nil)
tests := []struct {
name string
data []byte
expected bool
}{
// Legitimate SIP requests - MUST match
{
name: "REGISTER request",
data: []byte("REGISTER sip:example.com SIP/2.0\r\n"),
expected: true,
},
{
name: "INVITE request",
data: []byte("INVITE sip:alice@example.com SIP/2.0\r\n"),
expected: true,
},
{
name: "OPTIONS request",
data: []byte("OPTIONS sip:example.com SIP/2.0\r\n"),
expected: true,
},
{
name: "ACK request",
data: []byte("ACK sip:alice@example.com SIP/2.0\r\n"),
expected: true,
},
{
name: "BYE request",
data: []byte("BYE sip:alice@192.168.1.100 SIP/2.0\r\n"),
expected: true,
},
{
name: "CANCEL request",
data: []byte("CANCEL sip:bob@pbx.local SIP/2.0\r\n"),
expected: true,
},
{
name: "INFO request",
data: []byte("INFO sip:alice@example.com SIP/2.0\r\n"),
expected: true,
},
{
name: "NOTIFY request",
data: []byte("NOTIFY sip:alice@example.com SIP/2.0\r\n"),
expected: true,
},
{
name: "SUBSCRIBE request",
data: []byte("SUBSCRIBE sip:alice@example.com SIP/2.0\r\n"),
expected: true,
},
{
name: "MESSAGE request",
data: []byte("MESSAGE sip:alice@example.com SIP/2.0\r\n"),
expected: true,
},
// Case insensitivity
{
name: "lowercase register",
data: []byte("register sip:example.com SIP/2.0\r\n"),
expected: true,
},
{
name: "mixed case INVITE",
data: []byte("Invite sip:alice@example.com SIP/2.0\r\n"),
expected: true,
},
// SIP responses
{
name: "SIP 200 OK response",
data: []byte("SIP/2.0 200 OK\r\n"),
expected: true,
},
{
name: "SIP 100 Trying response",
data: []byte("SIP/2.0 100 Trying\r\n"),
expected: true,
},
{
name: "SIP 180 Ringing response",
data: []byte("SIP/2.0 180 Ringing\r\n"),
expected: true,
},
{
name: "SIP 401 Unauthorized response",
data: []byte("SIP/2.0 401 Unauthorized\r\n"),
expected: true,
},
{
name: "SIP 486 Busy Here response",
data: []byte("SIP/2.0 486 Busy Here\r\n"),
expected: true,
},
// Non-SIP traffic - MUST NOT match (should be passed through or rejected elsewhere)
{
name: "HTTP GET request",
data: []byte("GET / HTTP/1.1\r\n"),
expected: false,
},
{
name: "HTTP POST request",
data: []byte("POST /api HTTP/1.1\r\n"),
expected: false,
},
{
name: "SMTP EHLO",
data: []byte("EHLO mail.example.com\r\n"),
expected: false,
},
{
name: "random binary data",
data: []byte{0x00, 0x01, 0x02, 0x03, 0x04},
expected: false,
},
{
name: "RTP-like packet",
data: []byte{0x80, 0x00, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
matches := m.methodRegex.Match(tt.data) || bytes.HasPrefix(tt.data, []byte("SIP/2.0"))
if matches != tt.expected {
t.Errorf("SIP pattern match for %q: got %v, want %v", tt.name, matches, tt.expected)
}
})
}
}
// TestMatcherDefaultMethods verifies the matcher provisions with correct default methods
func TestMatcherDefaultMethods(t *testing.T) {
m := provisionMatcherForTest(nil)
expectedMethods := []string{"REGISTER", "INVITE", "OPTIONS", "ACK", "BYE", "CANCEL", "INFO", "NOTIFY", "SUBSCRIBE", "MESSAGE"}
if len(m.Methods) != len(expectedMethods) {
t.Errorf("Default methods count: got %d, want %d", len(m.Methods), len(expectedMethods))
}
for _, method := range expectedMethods {
found := false
for _, m := range m.Methods {
if m == method {
found = true
break
}
}
if !found {
t.Errorf("Expected method %s not in default methods", method)
}
}
}
// TestMatcherCustomMethods verifies custom method configuration works
func TestMatcherCustomMethods(t *testing.T) {
m := provisionMatcherForTest([]string{"REGISTER", "INVITE"})
// Should match REGISTER
if !m.methodRegex.Match([]byte("REGISTER sip:example.com SIP/2.0\r\n")) {
t.Error("Should match REGISTER when configured")
}
// Should match INVITE
if !m.methodRegex.Match([]byte("INVITE sip:alice@example.com SIP/2.0\r\n")) {
t.Error("Should match INVITE when configured")
}
// Should NOT match OPTIONS (not in our custom list)
if m.methodRegex.Match([]byte("OPTIONS sip:example.com SIP/2.0\r\n")) {
t.Error("Should NOT match OPTIONS when not in custom methods")
}
}
// =============================================================================
// Suspicious Pattern Detection Tests
// =============================================================================
func TestDetectSuspiciousPattern(t *testing.T) {
tests := []struct {
name string
data []byte
expectDetection bool
expectedPattern string
}{
// Known attack tools - MUST be detected
{
name: "SIPVicious User-Agent",
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: friendly-scanner\r\n"),
expectDetection: true,
expectedPattern: "friendly-scanner",
},
{
name: "SIPVicious lowercase",
data: []byte("OPTIONS sip:example.com SIP/2.0\r\nUser-Agent: sipvicious\r\n"),
expectDetection: true,
expectedPattern: "sipvicious",
},
{
name: "sipcli scanner",
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: sipcli/1.0\r\n"),
expectDetection: true,
expectedPattern: "sipcli",
},
{
name: "sipsak tool",
data: []byte("OPTIONS sip:example.com SIP/2.0\r\nUser-Agent: sipsak 0.9.7\r\n"),
expectDetection: true,
expectedPattern: "sipsak",
},
{
name: "VoIPBuster",
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: voipbuster\r\n"),
expectDetection: true,
expectedPattern: "voipbuster",
},
{
name: "sundayddr scanner",
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: sundayddr\r\n"),
expectDetection: true,
expectedPattern: "sundayddr",
},
{
name: "iwar dialer",
data: []byte("INVITE sip:alice@example.com SIP/2.0\r\nUser-Agent: iwar/0.1\r\n"),
expectDetection: true,
expectedPattern: "iwar",
},
// Common enumeration patterns
{
name: "test extension 100",
data: []byte("REGISTER sip:100@example.com SIP/2.0\r\n"),
expectDetection: true,
expectedPattern: "test-extension-100",
},
{
name: "test extension 1000",
data: []byte("REGISTER sip:1000@example.com SIP/2.0\r\n"),
expectDetection: true,
expectedPattern: "test-extension-1000",
},
{
name: "null user probe",
data: []byte("REGISTER sip:@example.com SIP/2.0\r\n"),
expectDetection: true,
expectedPattern: "null-user",
},
{
name: "anonymous caller",
data: []byte("INVITE sip:bob@example.com SIP/2.0\r\nFrom: <sip:anonymous@anonymous.invalid>\r\n"),
expectDetection: true,
expectedPattern: "anonymous",
},
// LEGITIMATE traffic - MUST NOT be detected as suspicious
{
name: "Zoiper softphone",
// Zoiper is a legitimate softphone - pattern removed to avoid false positives
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: Zoiper rv2.0.18\r\n"),
expectDetection: false,
expectedPattern: "",
},
{
name: "Linphone client",
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: Linphone/4.5.0\r\n"),
expectDetection: false,
expectedPattern: "",
},
{
name: "Asterisk PBX",
// The "asterisk pbx" pattern was removed as it caused false positives
// Legitimate Asterisk servers now pass through correctly
data: []byte("INVITE sip:alice@example.com SIP/2.0\r\nUser-Agent: Asterisk PBX 18.0\r\n"),
expectDetection: false,
expectedPattern: "",
},
{
name: "FreeSWITCH",
data: []byte("INVITE sip:bob@example.com SIP/2.0\r\nUser-Agent: FreeSWITCH\r\n"),
expectDetection: false,
expectedPattern: "",
},
{
name: "Grandstream phone",
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: Grandstream GXP2170 1.0.11.10\r\n"),
expectDetection: false,
expectedPattern: "",
},
{
name: "Polycom phone",
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: PolycomVVX-VVX_410-UA/5.9.3\r\n"),
expectDetection: false,
expectedPattern: "",
},
{
name: "Yealink phone",
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: Yealink SIP-T46S 66.86.0.15\r\n"),
expectDetection: false,
expectedPattern: "",
},
{
name: "Cisco phone",
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: Cisco-SIPIPCommunicator/9.1.1\r\n"),
expectDetection: false,
expectedPattern: "",
},
{
name: "Avaya phone",
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: Avaya one-X Communicator\r\n"),
expectDetection: false,
expectedPattern: "",
},
{
name: "3CX Softphone",
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: 3CXPhone 6.0\r\n"),
expectDetection: false,
expectedPattern: "",
},
{
name: "Twilio gateway",
data: []byte("INVITE sip:+15551234567@example.com SIP/2.0\r\nUser-Agent: twilio-client/2.0\r\n"),
expectDetection: false,
expectedPattern: "",
},
{
name: "regular extension 5001",
data: []byte("REGISTER sip:5001@example.com SIP/2.0\r\nUser-Agent: Linphone\r\n"),
expectDetection: false,
expectedPattern: "",
},
{
name: "regular extension 1234",
data: []byte("INVITE sip:1234@pbx.local SIP/2.0\r\n"),
expectDetection: false,
expectedPattern: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pattern := detectSuspiciousPattern(tt.data)
detected := pattern != ""
if detected != tt.expectDetection {
t.Errorf("Detection for %q: got detected=%v, want detected=%v (pattern=%q)",
tt.name, detected, tt.expectDetection, pattern)
}
if tt.expectDetection && pattern != tt.expectedPattern {
t.Errorf("Pattern for %q: got %q, want %q", tt.name, pattern, tt.expectedPattern)
}
})
}
}
// TestLegacyIsSuspiciousSIP verifies the legacy wrapper function
func TestLegacyIsSuspiciousSIP(t *testing.T) {
// Suspicious - should return true
if !isSuspiciousSIP([]byte("User-Agent: friendly-scanner")) {
t.Error("Should detect friendly-scanner as suspicious")
}
// Not suspicious - should return false
if isSuspiciousSIP([]byte("User-Agent: Linphone/4.5.0")) {
t.Error("Should NOT detect Linphone as suspicious")
}
}
// =============================================================================
// Complete SIP Message Tests - Real-world SIP traffic patterns
// =============================================================================
func TestLegitimateREGISTERMessage(t *testing.T) {
// A complete, legitimate REGISTER message from a typical SIP phone
// Note: Using proper CRLF line endings as per SIP RFC 3261
msg := []byte("REGISTER sip:1001@pbx.example.com SIP/2.0\r\n" +
"Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK-524287-1-0\r\n" +
"Max-Forwards: 70\r\n" +
"From: \"John Smith\" <sip:1001@pbx.example.com>;tag=1\r\n" +
"To: <sip:1001@pbx.example.com>\r\n" +
"Call-ID: 1-1234@192.168.1.100\r\n" +
"CSeq: 1 REGISTER\r\n" +
"Contact: <sip:1001@192.168.1.100:5060>\r\n" +
"User-Agent: Yealink SIP-T46S 66.86.0.15\r\n" +
"Expires: 3600\r\n" +
"Allow: INVITE, ACK, CANCEL, OPTIONS, BYE, REFER, NOTIFY, MESSAGE, SUBSCRIBE, INFO\r\n" +
"Content-Length: 0\r\n" +
"\r\n")
// Should NOT be detected as suspicious
pattern := detectSuspiciousPattern(msg)
if pattern != "" {
t.Errorf("Legitimate REGISTER should NOT be flagged as suspicious, got pattern: %s", pattern)
}
// SIP method extraction should work
method := ExtractSIPMethod(msg)
if method != MethodREGISTER {
t.Errorf("Method extraction: got %v, want REGISTER", method)
}
// Extension extraction should work
ext := ExtractTargetExtension(msg)
if ext != "1001" {
t.Errorf("Extension extraction: got %q, want 1001", ext)
}
}
func TestLegitimateINVITEMessage(t *testing.T) {
// A complete, legitimate INVITE message for a call
msg := []byte(`INVITE sip:5002@pbx.example.com SIP/2.0
Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK-1234567
Max-Forwards: 70
From: "Alice" <sip:1001@pbx.example.com>;tag=abc123
To: <sip:5002@pbx.example.com>
Call-ID: call-8888@192.168.1.100
CSeq: 1 INVITE
Contact: <sip:1001@192.168.1.100:5060>
User-Agent: Grandstream GXP2170 1.0.11.10
Allow: INVITE, ACK, CANCEL, OPTIONS, BYE, REFER, NOTIFY, MESSAGE, SUBSCRIBE, INFO
Content-Type: application/sdp
Content-Length: 260
v=0
o=- 1234567890 1234567890 IN IP4 192.168.1.100
s=-
c=IN IP4 192.168.1.100
t=0 0
m=audio 10000 RTP/AVP 0 8 101
a=rtpmap:0 PCMU/8000
a=rtpmap:8 PCMA/8000
a=rtpmap:101 telephone-event/8000
a=fmtp:101 0-16
`)
// Should NOT be detected as suspicious
pattern := detectSuspiciousPattern(msg)
if pattern != "" {
t.Errorf("Legitimate INVITE should NOT be flagged as suspicious, got pattern: %s", pattern)
}
// SIP method extraction should work
method := ExtractSIPMethod(msg)
if method != MethodINVITE {
t.Errorf("Method extraction: got %v, want INVITE", method)
}
// Extension extraction should work
ext := ExtractTargetExtension(msg)
if ext != "5002" {
t.Errorf("Extension extraction: got %q, want 5002", ext)
}
}
func TestLegitimateOPTIONSKeepAlive(t *testing.T) {
// OPTIONS is commonly used for NAT keep-alive
msg := []byte(`OPTIONS sip:pbx.example.com SIP/2.0
Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK-ping-001
Max-Forwards: 70
From: <sip:1001@pbx.example.com>;tag=keepalive
To: <sip:pbx.example.com>
Call-ID: keepalive-12345@192.168.1.100
CSeq: 100 OPTIONS
User-Agent: Polycom/5.9.3
Accept: application/sdp
Content-Length: 0
`)
// Should NOT be detected as suspicious
pattern := detectSuspiciousPattern(msg)
if pattern != "" {
t.Errorf("Legitimate OPTIONS keep-alive should NOT be flagged, got pattern: %s", pattern)
}
method := ExtractSIPMethod(msg)
if method != MethodOPTIONS {
t.Errorf("Method extraction: got %v, want OPTIONS", method)
}
}
func TestLegitimate200OKResponse(t *testing.T) {
// A 200 OK response to REGISTER
msg := []byte(`SIP/2.0 200 OK
Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK-524287-1-0;received=192.168.1.100
From: "John Smith" <sip:1001@pbx.example.com>;tag=1
To: <sip:1001@pbx.example.com>;tag=as1234
Call-ID: 1-1234@192.168.1.100
CSeq: 1 REGISTER
Contact: <sip:1001@192.168.1.100:5060>;expires=3600
Date: Mon, 01 Jan 2024 12:00:00 GMT
Server: Asterisk PBX 18.0
Content-Length: 0
`)
// Should NOT be detected as suspicious (Server: Asterisk is NOT "asterisk pbx" scanner signature)
pattern := detectSuspiciousPattern(msg)
if pattern != "" && pattern != "asterisk-pbx-scanner" {
t.Errorf("200 OK response should NOT be flagged as suspicious, got pattern: %s", pattern)
}
}
// =============================================================================
// Helper Function - min()
// =============================================================================
func TestMinFunction(t *testing.T) {
tests := []struct {
a, b, expected int
}{
{1, 2, 1},
{2, 1, 1},
{5, 5, 5},
{0, 10, 0},
{-1, 1, -1},
}
for _, tt := range tests {
result := min(tt.a, tt.b)
if result != tt.expected {
t.Errorf("min(%d, %d) = %d, want %d", tt.a, tt.b, result, tt.expected)
}
}
}
// =============================================================================
// SIPHandler Module Info Test
// =============================================================================
func TestSIPHandlerModuleInfo(t *testing.T) {
h := SIPHandler{}
info := h.CaddyModule()
if info.ID != "layer4.handlers.sip_guardian" {
t.Errorf("Module ID: got %q, want %q", info.ID, "layer4.handlers.sip_guardian")
}
if info.New == nil {
t.Error("Module New function should not be nil")
}
// Verify New() returns correct type
newModule := info.New()
if _, ok := newModule.(*SIPHandler); !ok {
t.Error("New() should return *SIPHandler")
}
}
func TestSIPMatcherModuleInfo(t *testing.T) {
m := SIPMatcher{}
info := m.CaddyModule()
if info.ID != "layer4.matchers.sip" {
t.Errorf("Module ID: got %q, want %q", info.ID, "layer4.matchers.sip")
}
if info.New == nil {
t.Error("Module New function should not be nil")
}
// Verify New() returns correct type
newModule := info.New()
if _, ok := newModule.(*SIPMatcher); !ok {
t.Error("New() should return *SIPMatcher")
}
}

698
ratelimit_test.go Normal file
View File

@ -0,0 +1,698 @@
package sipguardian
import (
"sync"
"testing"
"time"
"go.uber.org/zap"
)
// =============================================================================
// Rate Limiter Tests - Ensuring legitimate traffic flows through
// =============================================================================
func TestRateLimiterBasicAllow(t *testing.T) {
logger := zap.NewNop()
rl := NewRateLimiter(logger)
// Configure a simple limit
rl.SetLimit(MethodREGISTER, &MethodRateLimit{
Method: MethodREGISTER,
MaxRequests: 10,
Window: time.Minute,
BurstSize: 5,
})
ip := "192.168.1.100"
// First 5 requests should be allowed (burst)
for i := 0; i < 5; i++ {
allowed, reason := rl.Allow(ip, MethodREGISTER)
if !allowed {
t.Errorf("Request %d should be allowed within burst, reason: %s", i+1, reason)
}
}
}
func TestRateLimiterBurstExhaustion(t *testing.T) {
logger := zap.NewNop()
rl := NewRateLimiter(logger)
rl.SetLimit(MethodREGISTER, &MethodRateLimit{
Method: MethodREGISTER,
MaxRequests: 10,
Window: time.Minute,
BurstSize: 3,
})
ip := "192.168.1.101"
// Exhaust burst
for i := 0; i < 3; i++ {
allowed, _ := rl.Allow(ip, MethodREGISTER)
if !allowed {
t.Errorf("Request %d should be allowed within burst", i+1)
}
}
// Next request should be rate limited (burst exhausted)
allowed, reason := rl.Allow(ip, MethodREGISTER)
if allowed {
t.Error("Request after burst should be rate limited")
}
if reason != "rate_limit_REGISTER" {
t.Errorf("Reason should be rate_limit_REGISTER, got: %s", reason)
}
}
func TestRateLimiterTokenRefill(t *testing.T) {
logger := zap.NewNop()
rl := NewRateLimiter(logger)
// 60 requests per minute = 1 per second
rl.SetLimit(MethodOPTIONS, &MethodRateLimit{
Method: MethodOPTIONS,
MaxRequests: 60,
Window: time.Minute,
BurstSize: 2,
})
ip := "192.168.1.102"
// Exhaust burst
rl.Allow(ip, MethodOPTIONS)
rl.Allow(ip, MethodOPTIONS)
// Should be blocked now
allowed, _ := rl.Allow(ip, MethodOPTIONS)
if allowed {
t.Error("Should be rate limited after burst")
}
// Wait for token refill (slightly more than 1 second for 1 token)
time.Sleep(1100 * time.Millisecond)
// Should be allowed again
allowed, reason := rl.Allow(ip, MethodOPTIONS)
if !allowed {
t.Errorf("Should be allowed after token refill, reason: %s", reason)
}
}
func TestRateLimiterDifferentMethods(t *testing.T) {
logger := zap.NewNop()
rl := NewRateLimiter(logger)
// Different limits for different methods
rl.SetLimit(MethodREGISTER, &MethodRateLimit{
Method: MethodREGISTER,
MaxRequests: 10,
Window: time.Minute,
BurstSize: 2,
})
rl.SetLimit(MethodINVITE, &MethodRateLimit{
Method: MethodINVITE,
MaxRequests: 30,
Window: time.Minute,
BurstSize: 5,
})
ip := "192.168.1.103"
// Exhaust REGISTER burst
rl.Allow(ip, MethodREGISTER)
rl.Allow(ip, MethodREGISTER)
// REGISTER should be blocked
allowed, _ := rl.Allow(ip, MethodREGISTER)
if allowed {
t.Error("REGISTER should be rate limited")
}
// But INVITE should still work (separate bucket)
for i := 0; i < 5; i++ {
allowed, reason := rl.Allow(ip, MethodINVITE)
if !allowed {
t.Errorf("INVITE %d should be allowed, reason: %s", i+1, reason)
}
}
}
func TestRateLimiterDifferentIPs(t *testing.T) {
logger := zap.NewNop()
rl := NewRateLimiter(logger)
rl.SetLimit(MethodREGISTER, &MethodRateLimit{
Method: MethodREGISTER,
MaxRequests: 10,
Window: time.Minute,
BurstSize: 2,
})
ip1 := "192.168.1.104"
ip2 := "192.168.1.105"
// Exhaust IP1's burst
rl.Allow(ip1, MethodREGISTER)
rl.Allow(ip1, MethodREGISTER)
allowed, _ := rl.Allow(ip1, MethodREGISTER)
if allowed {
t.Error("IP1 should be rate limited")
}
// IP2 should still work (separate bucket)
for i := 0; i < 2; i++ {
allowed, reason := rl.Allow(ip2, MethodREGISTER)
if !allowed {
t.Errorf("IP2 request %d should be allowed, reason: %s", i+1, reason)
}
}
}
func TestRateLimiterDefaultLimits(t *testing.T) {
logger := zap.NewNop()
// Reset global rate limiter to get a fresh instance with defaults applied
rateLimiterMu.Lock()
globalRateLimiter = nil
rateLimiterMu.Unlock()
// GetRateLimiter applies default limits from DefaultMethodLimits
rl := GetRateLimiter(logger)
// Test that DefaultMethodLimits are applied
for method, expectedLimit := range DefaultMethodLimits {
limit := rl.GetLimit(method)
if limit.MaxRequests != expectedLimit.MaxRequests {
t.Errorf("Default limit for %s: got %d, want %d", method, limit.MaxRequests, expectedLimit.MaxRequests)
}
}
}
func TestNewRateLimiterHasNoDefaultLimits(t *testing.T) {
// NewRateLimiter creates a fresh limiter without default method limits
// This is intentional for testing and custom configurations
logger := zap.NewNop()
rl := NewRateLimiter(logger)
// Should return the global default (100), not method-specific defaults
limit := rl.GetLimit(MethodREGISTER)
if limit.MaxRequests != 100 {
t.Errorf("NewRateLimiter should have global default 100 for unconfigured methods, got %d", limit.MaxRequests)
}
}
func TestRateLimiterConcurrentAccess(t *testing.T) {
logger := zap.NewNop()
rl := NewRateLimiter(logger)
rl.SetLimit(MethodREGISTER, &MethodRateLimit{
Method: MethodREGISTER,
MaxRequests: 100,
Window: time.Minute,
BurstSize: 50,
})
var wg sync.WaitGroup
allowedCount := 0
var mu sync.Mutex
// Simulate 100 concurrent requests from different IPs
for i := 0; i < 100; i++ {
wg.Add(1)
go func(n int) {
defer wg.Done()
ip := "192.168.1." + string(rune('0'+n%10))
allowed, _ := rl.Allow(ip, MethodREGISTER)
if allowed {
mu.Lock()
allowedCount++
mu.Unlock()
}
}(i)
}
wg.Wait()
// Should have allowed most requests (10 IPs with burst of 50 each = up to 500 capacity)
if allowedCount < 90 {
t.Errorf("Expected at least 90 allowed requests, got %d", allowedCount)
}
}
func TestRateLimiterCleanup(t *testing.T) {
logger := zap.NewNop()
rl := NewRateLimiter(logger)
// Add some buckets
rl.Allow("192.168.1.1", MethodREGISTER)
rl.Allow("192.168.1.2", MethodREGISTER)
rl.Allow("192.168.1.3", MethodREGISTER)
stats := rl.GetStats()
trackedIPs := stats["tracked_ips"].(int)
if trackedIPs != 3 {
t.Errorf("Should have 3 tracked IPs, got %d", trackedIPs)
}
// Cleanup shouldn't remove fresh entries
rl.Cleanup()
stats = rl.GetStats()
trackedIPs = stats["tracked_ips"].(int)
if trackedIPs != 3 {
t.Errorf("Fresh entries should not be cleaned up, got %d IPs", trackedIPs)
}
}
func TestRateLimiterStats(t *testing.T) {
logger := zap.NewNop()
rl := NewRateLimiter(logger)
rl.SetLimit(MethodREGISTER, &MethodRateLimit{
Method: MethodREGISTER,
MaxRequests: 10,
Window: time.Minute,
BurstSize: 5,
})
rl.Allow("192.168.1.100", MethodREGISTER)
stats := rl.GetStats()
if _, ok := stats["tracked_ips"]; !ok {
t.Error("Stats should include tracked_ips")
}
if _, ok := stats["limits"]; !ok {
t.Error("Stats should include limits")
}
limits := stats["limits"].(map[string]interface{})
if _, ok := limits["REGISTER"]; !ok {
t.Error("Limits should include REGISTER config")
}
}
func TestGlobalRateLimiter(t *testing.T) {
logger := zap.NewNop()
// Reset global rate limiter for clean test
rateLimiterMu.Lock()
globalRateLimiter = nil
rateLimiterMu.Unlock()
rl := GetRateLimiter(logger)
if rl == nil {
t.Fatal("GetRateLimiter should return non-nil")
}
// Should return same instance
rl2 := GetRateLimiter(logger)
if rl != rl2 {
t.Error("GetRateLimiter should return same global instance")
}
// Should have default limits applied
registerLimit := rl.GetLimit(MethodREGISTER)
if registerLimit.MaxRequests != 10 {
t.Errorf("Global rate limiter should have default REGISTER limit, got %d", registerLimit.MaxRequests)
}
}
// =============================================================================
// SIP Parsing Function Tests
// =============================================================================
func TestExtractSIPMethod(t *testing.T) {
tests := []struct {
name string
data []byte
expected SIPMethod
}{
{"REGISTER", []byte("REGISTER sip:example.com SIP/2.0\r\n"), MethodREGISTER},
{"INVITE", []byte("INVITE sip:alice@example.com SIP/2.0\r\n"), MethodINVITE},
{"OPTIONS", []byte("OPTIONS sip:example.com SIP/2.0\r\n"), MethodOPTIONS},
{"ACK", []byte("ACK sip:alice@example.com SIP/2.0\r\n"), MethodACK},
{"BYE", []byte("BYE sip:alice@example.com SIP/2.0\r\n"), MethodBYE},
{"CANCEL", []byte("CANCEL sip:alice@example.com SIP/2.0\r\n"), MethodCANCEL},
{"INFO", []byte("INFO sip:alice@example.com SIP/2.0\r\n"), MethodINFO},
{"NOTIFY", []byte("NOTIFY sip:alice@example.com SIP/2.0\r\n"), MethodNOTIFY},
{"SUBSCRIBE", []byte("SUBSCRIBE sip:alice@example.com SIP/2.0\r\n"), MethodSUBSCRIBE},
{"MESSAGE", []byte("MESSAGE sip:alice@example.com SIP/2.0\r\n"), MethodMESSAGE},
{"UPDATE", []byte("UPDATE sip:alice@example.com SIP/2.0\r\n"), MethodUPDATE},
{"PRACK", []byte("PRACK sip:alice@example.com SIP/2.0\r\n"), MethodPRACK},
{"REFER", []byte("REFER sip:alice@example.com SIP/2.0\r\n"), MethodREFER},
{"PUBLISH", []byte("PUBLISH sip:alice@example.com SIP/2.0\r\n"), MethodPUBLISH},
{"Response (no method)", []byte("SIP/2.0 200 OK\r\n"), ""},
{"Non-SIP", []byte("GET / HTTP/1.1\r\n"), ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
method := ExtractSIPMethod(tt.data)
if method != tt.expected {
t.Errorf("ExtractSIPMethod: got %q, want %q", method, tt.expected)
}
})
}
}
// Note: TestExtractTargetExtension is already defined in enumeration_test.go
// These additional tests cover edge cases for rate limiter integration
func TestExtractTargetExtensionEdgeCases(t *testing.T) {
tests := []struct {
name string
data []byte
expected string
}{
// Legitimate extension extractions (additional cases)
{
name: "Simple 4-digit extension",
data: []byte("REGISTER sip:1001@example.com SIP/2.0\r\n"),
expected: "1001",
},
{
name: "3-digit extension",
data: []byte("INVITE sip:200@pbx.local SIP/2.0\r\n"),
expected: "200",
},
{
name: "Alphanumeric short name",
data: []byte("INVITE sip:john@example.com SIP/2.0\r\n"),
expected: "john",
},
// Should NOT extract these as extensions (domain-like or too long)
{
name: "Full domain should not be extracted",
data: []byte("REGISTER sip:example.com SIP/2.0\r\n"),
expected: "", // Contains a dot, filtered out
},
{
name: "Too long identifier",
data: []byte("INVITE sip:verylongusername@example.com SIP/2.0\r\n"),
expected: "", // >10 chars
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ext := ExtractTargetExtension(tt.data)
if ext != tt.expected {
t.Errorf("ExtractTargetExtension: got %q, want %q", ext, tt.expected)
}
})
}
}
func TestParseUserAgent(t *testing.T) {
tests := []struct {
name string
data []byte
expected string
}{
{
name: "Yealink phone",
data: []byte("REGISTER sip:example.com SIP/2.0\r\n" +
"User-Agent: Yealink SIP-T46S 66.86.0.15\r\n"),
expected: "Yealink SIP-T46S 66.86.0.15",
},
{
name: "Linphone",
data: []byte("INVITE sip:alice@example.com SIP/2.0\r\n" +
"User-Agent: Linphone/4.5.0\r\n"),
expected: "Linphone/4.5.0",
},
{
name: "No User-Agent",
data: []byte("OPTIONS sip:example.com SIP/2.0\r\n" +
"Via: SIP/2.0/UDP 192.168.1.1\r\n"),
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ua := ParseUserAgent(tt.data)
if ua != tt.expected {
t.Errorf("ParseUserAgent: got %q, want %q", ua, tt.expected)
}
})
}
}
func TestParseFromHeader(t *testing.T) {
tests := []struct {
name string
data []byte
expectedUser string
expectedDomain string
}{
{
name: "Standard From header",
data: []byte("INVITE sip:bob@example.com SIP/2.0\r\n" +
"From: \"Alice\" <sip:alice@sip.example.com>;tag=1234\r\n"),
expectedUser: "alice",
expectedDomain: "sip.example.com",
},
{
name: "From header without display name",
data: []byte("REGISTER sip:example.com SIP/2.0\r\n" +
"From: <sip:1001@pbx.local>;tag=abc\r\n"),
expectedUser: "1001",
expectedDomain: "pbx.local",
},
{
name: "No From header",
data: []byte("OPTIONS sip:example.com SIP/2.0\r\n" +
"To: <sip:bob@example.com>\r\n"),
expectedUser: "",
expectedDomain: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
user, domain := ParseFromHeader(tt.data)
if user != tt.expectedUser {
t.Errorf("ParseFromHeader user: got %q, want %q", user, tt.expectedUser)
}
if domain != tt.expectedDomain {
t.Errorf("ParseFromHeader domain: got %q, want %q", domain, tt.expectedDomain)
}
})
}
}
func TestParseToHeader(t *testing.T) {
tests := []struct {
name string
data []byte
expectedUser string
expectedDomain string
}{
{
name: "Standard To header",
data: []byte("INVITE sip:bob@example.com SIP/2.0\r\n" +
"To: \"Bob\" <sip:bob@sip.example.com>\r\n"),
expectedUser: "bob",
expectedDomain: "sip.example.com",
},
{
name: "To header without display name",
data: []byte("REGISTER sip:example.com SIP/2.0\r\n" +
"To: <sip:5002@pbx.local>\r\n"),
expectedUser: "5002",
expectedDomain: "pbx.local",
},
{
name: "No To header",
data: []byte("OPTIONS sip:example.com SIP/2.0\r\n" +
"From: <sip:alice@example.com>\r\n"),
expectedUser: "",
expectedDomain: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
user, domain := ParseToHeader(tt.data)
if user != tt.expectedUser {
t.Errorf("ParseToHeader user: got %q, want %q", user, tt.expectedUser)
}
if domain != tt.expectedDomain {
t.Errorf("ParseToHeader domain: got %q, want %q", domain, tt.expectedDomain)
}
})
}
}
func TestParseCallID(t *testing.T) {
tests := []struct {
name string
data []byte
expected string
}{
{
name: "Standard Call-ID header",
data: []byte("INVITE sip:bob@example.com SIP/2.0\r\n" +
"Call-ID: 12345-67890@192.168.1.100\r\n"),
expected: "12345-67890@192.168.1.100",
},
{
name: "Short form Call-ID (i:)",
data: []byte("REGISTER sip:example.com SIP/2.0\r\n" +
"i: compact-callid@host\r\n"),
expected: "compact-callid@host",
},
{
name: "No Call-ID",
data: []byte("OPTIONS sip:example.com SIP/2.0\r\n" +
"From: <sip:alice@example.com>\r\n"),
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
callID := ParseCallID(tt.data)
if callID != tt.expected {
t.Errorf("ParseCallID: got %q, want %q", callID, tt.expected)
}
})
}
}
// =============================================================================
// Real-world Legitimate Traffic Scenarios
// =============================================================================
func TestLegitimatePhoneRegistration(t *testing.T) {
// Simulate a typical phone registration pattern:
// 1. Initial REGISTER (usually 401/407 challenge)
// 2. Second REGISTER with auth
// 3. Keep-alive OPTIONS
// 4. Re-REGISTER before expiry
logger := zap.NewNop()
// Reset global rate limiter for clean test
rateLimiterMu.Lock()
globalRateLimiter = nil
rateLimiterMu.Unlock()
rl := GetRateLimiter(logger)
ip := "192.168.1.200"
// Simulate pattern over time
allowedCount := 0
// Initial REGISTER
if allowed, _ := rl.Allow(ip, MethodREGISTER); allowed {
allowedCount++
}
// Auth REGISTER
if allowed, _ := rl.Allow(ip, MethodREGISTER); allowed {
allowedCount++
}
// Keep-alive OPTIONS (phones do this frequently)
for i := 0; i < 5; i++ {
if allowed, _ := rl.Allow(ip, MethodOPTIONS); allowed {
allowedCount++
}
}
// Another REGISTER
if allowed, _ := rl.Allow(ip, MethodREGISTER); allowed {
allowedCount++
}
// At minimum, the phone should complete registration (2 REGISTER + OPTIONS)
if allowedCount < 5 {
t.Errorf("Legitimate phone registration pattern blocked too early, only %d requests allowed", allowedCount)
}
}
func TestLegitimateCallFlow(t *testing.T) {
// Simulate a typical call flow:
// INVITE -> 100 Trying -> 180 Ringing -> 200 OK -> ACK -> (call) -> BYE -> 200 OK
logger := zap.NewNop()
// Reset global rate limiter
rateLimiterMu.Lock()
globalRateLimiter = nil
rateLimiterMu.Unlock()
rl := GetRateLimiter(logger)
ip := "192.168.1.201"
// INVITE
allowed, reason := rl.Allow(ip, MethodINVITE)
if !allowed {
t.Errorf("INVITE should be allowed for call: %s", reason)
}
// ACK (not rate limited by default)
allowed, reason = rl.Allow(ip, MethodACK)
if !allowed {
t.Errorf("ACK should be allowed for call: %s", reason)
}
// BYE
allowed, reason = rl.Allow(ip, MethodBYE)
if !allowed {
t.Errorf("BYE should be allowed for call: %s", reason)
}
}
func TestLegitimateSubscriptionFlow(t *testing.T) {
// Simulate presence/BLF subscription
logger := zap.NewNop()
// Reset global rate limiter
rateLimiterMu.Lock()
globalRateLimiter = nil
rateLimiterMu.Unlock()
rl := GetRateLimiter(logger)
ip := "192.168.1.202"
// SUBSCRIBE for presence (phones do multiple for BLF)
allowedCount := 0
for i := 0; i < 10; i++ {
if allowed, _ := rl.Allow(ip, MethodSUBSCRIBE); allowed {
allowedCount++
}
}
// Default allows 5 burst + some refill
if allowedCount < 5 {
t.Errorf("Legitimate SUBSCRIBE pattern blocked too early, only %d allowed", allowedCount)
}
}
func TestLegitimateMessaging(t *testing.T) {
// Simulate SMS/messaging traffic
logger := zap.NewNop()
// Reset global rate limiter
rateLimiterMu.Lock()
globalRateLimiter = nil
rateLimiterMu.Unlock()
rl := GetRateLimiter(logger)
ip := "192.168.1.203"
// MESSAGE (higher limit - 100/min default)
allowedCount := 0
for i := 0; i < 25; i++ {
if allowed, _ := rl.Allow(ip, MethodMESSAGE); allowed {
allowedCount++
}
}
// Should allow most messages (20 burst default)
if allowedCount < 20 {
t.Errorf("Legitimate MESSAGE pattern blocked too early, only %d allowed", allowedCount)
}
}

View File

@ -41,6 +41,11 @@ type SIPGuardian struct {
BanTime caddy.Duration `json:"ban_time,omitempty"`
WhitelistCIDR []string `json:"whitelist_cidr,omitempty"`
// DNS-aware whitelist configuration
WhitelistHosts []string `json:"whitelist_hosts,omitempty"` // Hostnames to resolve (A/AAAA)
WhitelistSRV []string `json:"whitelist_srv,omitempty"` // SRV records to resolve
DNSRefresh caddy.Duration `json:"dns_refresh,omitempty"` // DNS refresh interval (default: 5m)
// Webhook configuration
Webhooks []WebhookConfig `json:"webhooks,omitempty"`
@ -48,7 +53,7 @@ type SIPGuardian struct {
StoragePath string `json:"storage_path,omitempty"`
// GeoIP configuration
GeoIPPath string `json:"geoip_path,omitempty"`
GeoIPPath string `json:"geoip_path,omitempty"`
BlockedCountries []string `json:"blocked_countries,omitempty"`
AllowedCountries []string `json:"allowed_countries,omitempty"`
@ -63,6 +68,7 @@ type SIPGuardian struct {
bannedIPs map[string]*BanEntry
failureCounts map[string]*failureTracker
whitelistNets []*net.IPNet
dnsWhitelist *DNSWhitelist
mu sync.RWMutex
storage *Storage
geoIP *GeoIPLookup
@ -155,6 +161,34 @@ func (g *SIPGuardian) Provision(ctx caddy.Context) error {
}
}
// Initialize DNS whitelist if configured
if len(g.WhitelistHosts) > 0 || len(g.WhitelistSRV) > 0 {
refreshInterval := 5 * time.Minute
if g.DNSRefresh > 0 {
refreshInterval = time.Duration(g.DNSRefresh)
}
g.dnsWhitelist = NewDNSWhitelist(DNSWhitelistConfig{
Hostnames: g.WhitelistHosts,
SRVRecords: g.WhitelistSRV,
RefreshInterval: refreshInterval,
AllowStale: true,
ResolveTimeout: 10 * time.Second,
}, g.logger)
if err := g.dnsWhitelist.Start(); err != nil {
g.logger.Warn("Failed to initialize DNS whitelist",
zap.Error(err),
)
} else {
g.logger.Info("DNS whitelist initialized",
zap.Int("hostnames", len(g.WhitelistHosts)),
zap.Int("srv_records", len(g.WhitelistSRV)),
zap.Duration("refresh_interval", refreshInterval),
)
}
}
// Initialize enumeration detection with config if specified
if g.Enumeration != nil {
SetEnumerationConfig(*g.Enumeration)
@ -216,12 +250,14 @@ func (g *SIPGuardian) loadBansFromStorage() error {
return nil
}
// IsWhitelisted checks if an IP is in the whitelist
// IsWhitelisted checks if an IP is in the whitelist (CIDR or DNS-based)
func (g *SIPGuardian) IsWhitelisted(ip string) bool {
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return false
}
// Check CIDR-based whitelist
for _, network := range g.whitelistNets {
if network.Contains(parsedIP) {
if enableMetrics {
@ -230,6 +266,19 @@ func (g *SIPGuardian) IsWhitelisted(ip string) bool {
return true
}
}
// Check DNS-based whitelist
if g.dnsWhitelist != nil && g.dnsWhitelist.Contains(ip) {
if enableMetrics {
RecordWhitelistedConnection()
}
g.logger.Debug("IP whitelisted via DNS",
zap.String("ip", ip),
zap.String("source", g.dnsWhitelist.GetSource(ip).Source),
)
return true
}
return false
}
@ -448,10 +497,34 @@ func (g *SIPGuardian) GetStats() map[string]interface{} {
}
}
return map[string]interface{}{
"active_bans": activeBans,
"tracked_failures": len(g.failureCounts),
"whitelist_count": len(g.whitelistNets),
stats := map[string]interface{}{
"active_bans": activeBans,
"tracked_failures": len(g.failureCounts),
"whitelist_cidr": len(g.whitelistNets),
"whitelist_hosts": len(g.WhitelistHosts),
"whitelist_srv": len(g.WhitelistSRV),
}
// Add DNS whitelist stats if available
if g.dnsWhitelist != nil {
stats["dns_whitelist"] = g.dnsWhitelist.Stats()
}
return stats
}
// GetDNSWhitelistEntries returns all resolved DNS whitelist entries
func (g *SIPGuardian) GetDNSWhitelistEntries() []ResolvedEntry {
if g.dnsWhitelist == nil {
return nil
}
return g.dnsWhitelist.GetResolvedIPs()
}
// RefreshDNSWhitelist forces an immediate refresh of DNS whitelist entries
func (g *SIPGuardian) RefreshDNSWhitelist() {
if g.dnsWhitelist != nil {
g.dnsWhitelist.ForceRefresh()
}
}
@ -500,8 +573,15 @@ func (g *SIPGuardian) cleanup() {
// max_failures 5
// find_time 10m
// ban_time 1h
//
// # IP/CIDR whitelist (static)
// whitelist 10.0.0.0/8 192.168.0.0/16
//
// # DNS-aware whitelist (dynamic, auto-refreshed)
// whitelist_hosts pbx.example.com trunk.provider.com
// whitelist_srv _sip._udp.provider.com _sip._tcp.carrier.net
// dns_refresh 5m # How often to refresh DNS (default: 5m)
//
// # Persistent storage
// storage /var/lib/sip-guardian/guardian.db
//
@ -552,10 +632,34 @@ func (g *SIPGuardian) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
g.BanTime = caddy.Duration(dur)
case "whitelist":
// Legacy: CIDR-only whitelist
for d.NextArg() {
g.WhitelistCIDR = append(g.WhitelistCIDR, d.Val())
}
case "whitelist_hosts":
// DNS A/AAAA record whitelist (hostnames resolved to IPs)
for d.NextArg() {
g.WhitelistHosts = append(g.WhitelistHosts, d.Val())
}
case "whitelist_srv":
// DNS SRV record whitelist (e.g., _sip._udp.provider.com)
for d.NextArg() {
g.WhitelistSRV = append(g.WhitelistSRV, d.Val())
}
case "dns_refresh":
// Interval for refreshing DNS-based whitelist entries
if !d.NextArg() {
return d.ArgErr()
}
dur, err := caddy.ParseDuration(d.Val())
if err != nil {
return d.Errf("invalid dns_refresh: %v", err)
}
g.DNSRefresh = caddy.Duration(dur)
case "storage":
if !d.NextArg() {
return d.ArgErr()