Add DNS-aware whitelisting feature
Support for whitelisting SIP trunks and providers by hostname or SRV record with automatic IP resolution and periodic refresh. Features: - Hostname resolution via A/AAAA records - SRV record resolution (e.g., _sip._udp.provider.com) - Configurable refresh interval (default 5m) - Stale entry handling when DNS fails - Admin API endpoints for DNS whitelist management - Caddyfile directives: whitelist_hosts, whitelist_srv, dns_refresh This allows whitelisting by provider name rather than tracking constantly-changing IP addresses.
This commit is contained in:
parent
46a47ce2c6
commit
5cf34eb3c0
45
README.md
45
README.md
@ -3,7 +3,7 @@
|
|||||||
[](https://go.dev/)
|
[](https://go.dev/)
|
||||||
[](https://caddyserver.com/)
|
[](https://caddyserver.com/)
|
||||||
[](LICENSE)
|
[](LICENSE)
|
||||||
[](https://git.supported.systems/rsp2k/caddy-sip-guardian)
|
[](https://git.supported.systems/rsp2k/caddy-sip-guardian)
|
||||||
|
|
||||||
> **A comprehensive Caddy module providing SIP-aware security at Layer 4.**
|
> **A comprehensive Caddy module providing SIP-aware security at Layer 4.**
|
||||||
> Protects your VoIP infrastructure with intelligent rate limiting, attack detection, message validation, and topology hiding.
|
> Protects your VoIP infrastructure with intelligent rate limiting, attack detection, message validation, and topology hiding.
|
||||||
@ -31,7 +31,8 @@ Traditional SIP security (like fail2ban) parses logs *after* attacks reach your
|
|||||||
- **Intelligent Rate Limiting** — Per-method token bucket rate limiting with burst support
|
- **Intelligent Rate Limiting** — Per-method token bucket rate limiting with burst support
|
||||||
- **Automatic Banning** — Ban IPs that exceed failure thresholds
|
- **Automatic Banning** — Ban IPs that exceed failure thresholds
|
||||||
- **Attack Detection** — Detect common SIP scanning tools (SIPVicious, friendly-scanner, etc.)
|
- **Attack Detection** — Detect common SIP scanning tools (SIPVicious, friendly-scanner, etc.)
|
||||||
- **CIDR Whitelisting** — Whitelist trusted networks
|
- **CIDR Whitelisting** — Whitelist trusted networks by IP range
|
||||||
|
- **DNS-aware Whitelisting** — Whitelist SIP trunks by hostname or SRV record with auto-refresh
|
||||||
- **GeoIP Blocking** — Block traffic by country using MaxMind databases
|
- **GeoIP Blocking** — Block traffic by country using MaxMind databases
|
||||||
|
|
||||||
### 🔍 Extension Enumeration Detection
|
### 🔍 Extension Enumeration Detection
|
||||||
@ -298,6 +299,46 @@ enumeration {
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
### DNS-aware Whitelisting
|
||||||
|
|
||||||
|
Whitelist SIP trunks and providers by hostname or SRV record. IPs are automatically resolved and refreshed:
|
||||||
|
|
||||||
|
```caddyfile
|
||||||
|
sip_guardian {
|
||||||
|
# Static CIDR whitelist (always available)
|
||||||
|
whitelist 10.0.0.0/8 192.168.0.0/16
|
||||||
|
|
||||||
|
# DNS-aware whitelist - resolved to IPs automatically
|
||||||
|
whitelist_hosts pbx.example.com trunk.sipcarrier.net
|
||||||
|
whitelist_srv _sip._udp.provider.com _sip._tcp.carrier.net
|
||||||
|
dns_refresh 5m # How often to refresh DNS lookups (default: 5m)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Why DNS-aware whitelisting?**
|
||||||
|
|
||||||
|
| Static IP Whitelisting | DNS-aware Whitelisting |
|
||||||
|
|------------------------|------------------------|
|
||||||
|
| Breaks when provider changes IPs | Auto-updates when IPs change |
|
||||||
|
| Must manually track carrier IPs | Just use their SRV record |
|
||||||
|
| Fails silently on changes | Logs refresh events |
|
||||||
|
|
||||||
|
**SRV Record Support:**
|
||||||
|
|
||||||
|
SIP trunks commonly use SRV records for load balancing and failover. SIP Guardian resolves the full chain:
|
||||||
|
```
|
||||||
|
_sip._udp.carrier.com → sip1.carrier.com, sip2.carrier.com → 203.0.113.10, 203.0.113.11
|
||||||
|
```
|
||||||
|
|
||||||
|
**Admin API Endpoints:**
|
||||||
|
|
||||||
|
| Method | Endpoint | Description |
|
||||||
|
|--------|----------|-------------|
|
||||||
|
| `GET` | `/api/sip-guardian/dns-whitelist` | List all resolved DNS entries |
|
||||||
|
| `POST` | `/api/sip-guardian/dns-whitelist/refresh` | Force immediate DNS refresh |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
### SIP Message Validation
|
### SIP Message Validation
|
||||||
|
|
||||||
Enforces RFC 3261 compliance and blocks malformed/malicious packets:
|
Enforces RFC 3261 compliance and blocks malformed/malicious packets:
|
||||||
|
|||||||
48
admin.go
48
admin.go
@ -57,6 +57,10 @@ func (h *AdminHandler) ServeHTTP(w http.ResponseWriter, r *http.Request, next ca
|
|||||||
return h.handleBans(w, r)
|
return h.handleBans(w, r)
|
||||||
case strings.HasSuffix(path, "/stats"):
|
case strings.HasSuffix(path, "/stats"):
|
||||||
return h.handleStats(w, r)
|
return h.handleStats(w, r)
|
||||||
|
case strings.HasSuffix(path, "/dns-whitelist"):
|
||||||
|
return h.handleDNSWhitelist(w, r)
|
||||||
|
case strings.HasSuffix(path, "/dns-whitelist/refresh"):
|
||||||
|
return h.handleDNSWhitelistRefresh(w, r)
|
||||||
case strings.Contains(path, "/unban/"):
|
case strings.Contains(path, "/unban/"):
|
||||||
return h.handleUnban(w, r, path)
|
return h.handleUnban(w, r, path)
|
||||||
case strings.Contains(path, "/ban/"):
|
case strings.Contains(path, "/ban/"):
|
||||||
@ -131,6 +135,50 @@ func (h *AdminHandler) handleUnban(w http.ResponseWriter, r *http.Request, path
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleDNSWhitelist returns DNS whitelist entries and stats
|
||||||
|
func (h *AdminHandler) handleDNSWhitelist(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := h.guardian.GetDNSWhitelistEntries()
|
||||||
|
stats := h.guardian.GetStats()
|
||||||
|
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"entries": entries,
|
||||||
|
"count": len(entries),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add DNS-specific stats if available
|
||||||
|
if dnsStats, ok := stats["dns_whitelist"]; ok {
|
||||||
|
response["stats"] = dnsStats
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
return json.NewEncoder(w).Encode(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleDNSWhitelistRefresh forces an immediate DNS refresh
|
||||||
|
func (h *AdminHandler) handleDNSWhitelistRefresh(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
h.guardian.RefreshDNSWhitelist()
|
||||||
|
|
||||||
|
// Get updated entries
|
||||||
|
entries := h.guardian.GetDNSWhitelistEntries()
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
return json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"success": true,
|
||||||
|
"message": "DNS whitelist refreshed",
|
||||||
|
"count": len(entries),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// handleBan manually adds an IP to the ban list
|
// handleBan manually adds an IP to the ban list
|
||||||
func (h *AdminHandler) handleBan(w http.ResponseWriter, r *http.Request, path string) error {
|
func (h *AdminHandler) handleBan(w http.ResponseWriter, r *http.Request, path string) error {
|
||||||
if r.Method != http.MethodPost {
|
if r.Method != http.MethodPost {
|
||||||
|
|||||||
382
dns_whitelist.go
Normal file
382
dns_whitelist.go
Normal file
@ -0,0 +1,382 @@
|
|||||||
|
// Package sipguardian provides DNS-aware whitelist functionality for SIP Guardian.
|
||||||
|
// This allows whitelisting by hostname, A/AAAA records, and SRV records.
|
||||||
|
package sipguardian
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DNSWhitelistConfig holds configuration for DNS-based whitelisting
|
||||||
|
type DNSWhitelistConfig struct {
|
||||||
|
// Hostnames to resolve and whitelist (A/AAAA records)
|
||||||
|
Hostnames []string `json:"hostnames,omitempty"`
|
||||||
|
|
||||||
|
// SRV records to resolve (e.g., "_sip._udp.provider.com")
|
||||||
|
// Will resolve SRV -> hostnames -> IPs
|
||||||
|
SRVRecords []string `json:"srv_records,omitempty"`
|
||||||
|
|
||||||
|
// RefreshInterval for DNS lookups (default: 5m)
|
||||||
|
RefreshInterval time.Duration `json:"refresh_interval,omitempty"`
|
||||||
|
|
||||||
|
// AllowStale allows using cached IPs if DNS refresh fails (default: true)
|
||||||
|
AllowStale bool `json:"allow_stale,omitempty"`
|
||||||
|
|
||||||
|
// ResolveTimeout for individual DNS queries (default: 10s)
|
||||||
|
ResolveTimeout time.Duration `json:"resolve_timeout,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DNSWhitelist manages DNS-based IP whitelisting with automatic refresh
|
||||||
|
type DNSWhitelist struct {
|
||||||
|
config DNSWhitelistConfig
|
||||||
|
logger *zap.Logger
|
||||||
|
|
||||||
|
// Resolved IPs with their source info
|
||||||
|
resolvedIPs map[string]*ResolvedEntry
|
||||||
|
mu sync.RWMutex
|
||||||
|
|
||||||
|
// For graceful shutdown
|
||||||
|
stopCh chan struct{}
|
||||||
|
wg sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolvedEntry tracks where an IP came from
|
||||||
|
type ResolvedEntry struct {
|
||||||
|
IP string `json:"ip"`
|
||||||
|
Source string `json:"source"` // hostname or SRV record
|
||||||
|
SourceType string `json:"source_type"` // "hostname", "srv"
|
||||||
|
ResolvedAt time.Time `json:"resolved_at"`
|
||||||
|
ExpiresAt time.Time `json:"expires_at"`
|
||||||
|
TTL int `json:"ttl"` // DNS TTL in seconds
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDNSWhitelist creates a new DNS whitelist manager
|
||||||
|
func NewDNSWhitelist(config DNSWhitelistConfig, logger *zap.Logger) *DNSWhitelist {
|
||||||
|
// Set defaults
|
||||||
|
if config.RefreshInterval == 0 {
|
||||||
|
config.RefreshInterval = 5 * time.Minute
|
||||||
|
}
|
||||||
|
if config.ResolveTimeout == 0 {
|
||||||
|
config.ResolveTimeout = 10 * time.Second
|
||||||
|
}
|
||||||
|
if !config.AllowStale {
|
||||||
|
config.AllowStale = true // Default to allowing stale on DNS failure
|
||||||
|
}
|
||||||
|
|
||||||
|
return &DNSWhitelist{
|
||||||
|
config: config,
|
||||||
|
logger: logger,
|
||||||
|
resolvedIPs: make(map[string]*ResolvedEntry),
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start begins the DNS refresh loop
|
||||||
|
func (d *DNSWhitelist) Start() error {
|
||||||
|
// Do initial resolution
|
||||||
|
d.refreshAll()
|
||||||
|
|
||||||
|
// Start background refresh
|
||||||
|
d.wg.Add(1)
|
||||||
|
go d.refreshLoop()
|
||||||
|
|
||||||
|
d.logger.Info("DNS whitelist started",
|
||||||
|
zap.Int("hostnames", len(d.config.Hostnames)),
|
||||||
|
zap.Int("srv_records", len(d.config.SRVRecords)),
|
||||||
|
zap.Duration("refresh_interval", d.config.RefreshInterval),
|
||||||
|
)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the DNS refresh loop
|
||||||
|
func (d *DNSWhitelist) Stop() {
|
||||||
|
close(d.stopCh)
|
||||||
|
d.wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Contains checks if an IP is in the DNS whitelist
|
||||||
|
func (d *DNSWhitelist) Contains(ip string) bool {
|
||||||
|
d.mu.RLock()
|
||||||
|
defer d.mu.RUnlock()
|
||||||
|
|
||||||
|
entry, exists := d.resolvedIPs[ip]
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if entry is still valid
|
||||||
|
if time.Now().After(entry.ExpiresAt) && !d.config.AllowStale {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSource returns the source info for a whitelisted IP
|
||||||
|
func (d *DNSWhitelist) GetSource(ip string) *ResolvedEntry {
|
||||||
|
d.mu.RLock()
|
||||||
|
defer d.mu.RUnlock()
|
||||||
|
|
||||||
|
if entry, exists := d.resolvedIPs[ip]; exists {
|
||||||
|
// Return copy to avoid race conditions
|
||||||
|
entryCopy := *entry
|
||||||
|
return &entryCopy
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetResolvedIPs returns all currently resolved IPs
|
||||||
|
func (d *DNSWhitelist) GetResolvedIPs() []ResolvedEntry {
|
||||||
|
d.mu.RLock()
|
||||||
|
defer d.mu.RUnlock()
|
||||||
|
|
||||||
|
entries := make([]ResolvedEntry, 0, len(d.resolvedIPs))
|
||||||
|
for _, entry := range d.resolvedIPs {
|
||||||
|
entries = append(entries, *entry)
|
||||||
|
}
|
||||||
|
return entries
|
||||||
|
}
|
||||||
|
|
||||||
|
// refreshLoop periodically refreshes DNS entries
|
||||||
|
func (d *DNSWhitelist) refreshLoop() {
|
||||||
|
defer d.wg.Done()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(d.config.RefreshInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-d.stopCh:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
d.refreshAll()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// refreshAll resolves all configured hostnames and SRV records
|
||||||
|
func (d *DNSWhitelist) refreshAll() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), d.config.ResolveTimeout*time.Duration(len(d.config.Hostnames)+len(d.config.SRVRecords)+1))
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
newResolved := make(map[string]*ResolvedEntry)
|
||||||
|
now := time.Now()
|
||||||
|
defaultExpiry := now.Add(d.config.RefreshInterval * 2) // 2x refresh interval as safety margin
|
||||||
|
|
||||||
|
// Resolve hostnames (A/AAAA records)
|
||||||
|
for _, hostname := range d.config.Hostnames {
|
||||||
|
ips, err := d.resolveHostname(ctx, hostname)
|
||||||
|
if err != nil {
|
||||||
|
d.logger.Warn("Failed to resolve hostname",
|
||||||
|
zap.String("hostname", hostname),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
// Keep existing entries if AllowStale
|
||||||
|
if d.config.AllowStale {
|
||||||
|
d.copyExistingEntries(newResolved, hostname, "hostname")
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ip := range ips {
|
||||||
|
newResolved[ip] = &ResolvedEntry{
|
||||||
|
IP: ip,
|
||||||
|
Source: hostname,
|
||||||
|
SourceType: "hostname",
|
||||||
|
ResolvedAt: now,
|
||||||
|
ExpiresAt: defaultExpiry,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
d.logger.Debug("Resolved hostname",
|
||||||
|
zap.String("hostname", hostname),
|
||||||
|
zap.Strings("ips", ips),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve SRV records
|
||||||
|
for _, srv := range d.config.SRVRecords {
|
||||||
|
ips, targets, err := d.resolveSRV(ctx, srv)
|
||||||
|
if err != nil {
|
||||||
|
d.logger.Warn("Failed to resolve SRV record",
|
||||||
|
zap.String("srv", srv),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
// Keep existing entries if AllowStale
|
||||||
|
if d.config.AllowStale {
|
||||||
|
d.copyExistingEntries(newResolved, srv, "srv")
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ip := range ips {
|
||||||
|
newResolved[ip] = &ResolvedEntry{
|
||||||
|
IP: ip,
|
||||||
|
Source: srv,
|
||||||
|
SourceType: "srv",
|
||||||
|
ResolvedAt: now,
|
||||||
|
ExpiresAt: defaultExpiry,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
d.logger.Debug("Resolved SRV record",
|
||||||
|
zap.String("srv", srv),
|
||||||
|
zap.Strings("targets", targets),
|
||||||
|
zap.Strings("ips", ips),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the map atomically
|
||||||
|
d.mu.Lock()
|
||||||
|
d.resolvedIPs = newResolved
|
||||||
|
d.mu.Unlock()
|
||||||
|
|
||||||
|
d.logger.Info("DNS whitelist refreshed",
|
||||||
|
zap.Int("total_ips", len(newResolved)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// copyExistingEntries copies existing entries for a source to the new map
|
||||||
|
func (d *DNSWhitelist) copyExistingEntries(newMap map[string]*ResolvedEntry, source, sourceType string) {
|
||||||
|
d.mu.RLock()
|
||||||
|
defer d.mu.RUnlock()
|
||||||
|
|
||||||
|
for ip, entry := range d.resolvedIPs {
|
||||||
|
if entry.Source == source && entry.SourceType == sourceType {
|
||||||
|
newMap[ip] = entry
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveHostname resolves a hostname to IP addresses
|
||||||
|
func (d *DNSWhitelist) resolveHostname(ctx context.Context, hostname string) ([]string, error) {
|
||||||
|
// Handle case where hostname is already an IP
|
||||||
|
if ip := net.ParseIP(hostname); ip != nil {
|
||||||
|
return []string{hostname}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
resolver := net.DefaultResolver
|
||||||
|
addrs, err := resolver.LookupIPAddr(ctx, hostname)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ips := make([]string, 0, len(addrs))
|
||||||
|
for _, addr := range addrs {
|
||||||
|
ips = append(ips, addr.IP.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return ips, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveSRV resolves an SRV record and all its targets
|
||||||
|
func (d *DNSWhitelist) resolveSRV(ctx context.Context, srvRecord string) (ips []string, targets []string, err error) {
|
||||||
|
// Parse SRV record format: _service._proto.name or just name
|
||||||
|
// Example: _sip._udp.provider.com or _sip._tcp.provider.com
|
||||||
|
var service, proto, name string
|
||||||
|
|
||||||
|
parts := strings.Split(srvRecord, ".")
|
||||||
|
if len(parts) >= 3 && strings.HasPrefix(parts[0], "_") && strings.HasPrefix(parts[1], "_") {
|
||||||
|
service = strings.TrimPrefix(parts[0], "_")
|
||||||
|
proto = strings.TrimPrefix(parts[1], "_")
|
||||||
|
name = strings.Join(parts[2:], ".")
|
||||||
|
} else {
|
||||||
|
// Assume it's a plain domain, try common SIP SRV patterns
|
||||||
|
// Try _sip._udp first, then _sip._tcp
|
||||||
|
service = "sip"
|
||||||
|
proto = "udp"
|
||||||
|
name = srvRecord
|
||||||
|
}
|
||||||
|
|
||||||
|
resolver := net.DefaultResolver
|
||||||
|
|
||||||
|
// Try to resolve SRV
|
||||||
|
_, srvRecords, err := resolver.LookupSRV(ctx, service, proto, name)
|
||||||
|
if err != nil {
|
||||||
|
// If UDP fails, try TCP
|
||||||
|
if proto == "udp" {
|
||||||
|
_, srvRecords, err = resolver.LookupSRV(ctx, service, "tcp", name)
|
||||||
|
if err != nil {
|
||||||
|
// Fall back to A record lookup on the original name
|
||||||
|
directIPs, aErr := d.resolveHostname(ctx, srvRecord)
|
||||||
|
if aErr != nil {
|
||||||
|
return nil, nil, err // Return original SRV error
|
||||||
|
}
|
||||||
|
return directIPs, []string{srvRecord}, nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve each SRV target to IPs
|
||||||
|
seenIPs := make(map[string]bool)
|
||||||
|
for _, srv := range srvRecords {
|
||||||
|
target := strings.TrimSuffix(srv.Target, ".")
|
||||||
|
targets = append(targets, target)
|
||||||
|
|
||||||
|
targetIPs, err := d.resolveHostname(ctx, target)
|
||||||
|
if err != nil {
|
||||||
|
d.logger.Warn("Failed to resolve SRV target",
|
||||||
|
zap.String("srv", srvRecord),
|
||||||
|
zap.String("target", target),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ip := range targetIPs {
|
||||||
|
if !seenIPs[ip] {
|
||||||
|
seenIPs[ip] = true
|
||||||
|
ips = append(ips, ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ips, targets, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForceRefresh triggers an immediate refresh of DNS entries
|
||||||
|
func (d *DNSWhitelist) ForceRefresh() {
|
||||||
|
d.refreshAll()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stats returns statistics about the DNS whitelist
|
||||||
|
func (d *DNSWhitelist) Stats() map[string]interface{} {
|
||||||
|
d.mu.RLock()
|
||||||
|
defer d.mu.RUnlock()
|
||||||
|
|
||||||
|
hostnameIPs := 0
|
||||||
|
srvIPs := 0
|
||||||
|
staleCount := 0
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
for _, entry := range d.resolvedIPs {
|
||||||
|
switch entry.SourceType {
|
||||||
|
case "hostname":
|
||||||
|
hostnameIPs++
|
||||||
|
case "srv":
|
||||||
|
srvIPs++
|
||||||
|
}
|
||||||
|
if now.After(entry.ExpiresAt) {
|
||||||
|
staleCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return map[string]interface{}{
|
||||||
|
"total_ips": len(d.resolvedIPs),
|
||||||
|
"hostname_ips": hostnameIPs,
|
||||||
|
"srv_ips": srvIPs,
|
||||||
|
"stale_count": staleCount,
|
||||||
|
"configured_hosts": len(d.config.Hostnames),
|
||||||
|
"configured_srv": len(d.config.SRVRecords),
|
||||||
|
"refresh_interval": d.config.RefreshInterval.String(),
|
||||||
|
"allow_stale": d.config.AllowStale,
|
||||||
|
}
|
||||||
|
}
|
||||||
500
dns_whitelist_test.go
Normal file
500
dns_whitelist_test.go
Normal file
@ -0,0 +1,500 @@
|
|||||||
|
package sipguardian
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestDNSWhitelistConfig tests configuration defaults and validation
|
||||||
|
func TestDNSWhitelistConfig(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config DNSWhitelistConfig
|
||||||
|
expectedRefresh time.Duration
|
||||||
|
expectedTimeout time.Duration
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty config uses defaults",
|
||||||
|
config: DNSWhitelistConfig{},
|
||||||
|
expectedRefresh: 5 * time.Minute,
|
||||||
|
expectedTimeout: 10 * time.Second,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "custom refresh interval",
|
||||||
|
config: DNSWhitelistConfig{
|
||||||
|
RefreshInterval: 10 * time.Minute,
|
||||||
|
},
|
||||||
|
expectedRefresh: 10 * time.Minute,
|
||||||
|
expectedTimeout: 10 * time.Second,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "custom timeout",
|
||||||
|
config: DNSWhitelistConfig{
|
||||||
|
ResolveTimeout: 30 * time.Second,
|
||||||
|
},
|
||||||
|
expectedRefresh: 5 * time.Minute,
|
||||||
|
expectedTimeout: 30 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
wl := NewDNSWhitelist(tt.config, logger)
|
||||||
|
|
||||||
|
if wl.config.RefreshInterval != tt.expectedRefresh {
|
||||||
|
t.Errorf("RefreshInterval = %v, want %v", wl.config.RefreshInterval, tt.expectedRefresh)
|
||||||
|
}
|
||||||
|
if wl.config.ResolveTimeout != tt.expectedTimeout {
|
||||||
|
t.Errorf("ResolveTimeout = %v, want %v", wl.config.ResolveTimeout, tt.expectedTimeout)
|
||||||
|
}
|
||||||
|
if !wl.config.AllowStale {
|
||||||
|
t.Error("AllowStale should default to true")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDNSWhitelistContains tests the IP lookup functionality
|
||||||
|
func TestDNSWhitelistContains(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger)
|
||||||
|
|
||||||
|
// Manually add some entries for testing
|
||||||
|
now := time.Now()
|
||||||
|
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
|
||||||
|
IP: "192.168.1.100",
|
||||||
|
Source: "test.example.com",
|
||||||
|
SourceType: "hostname",
|
||||||
|
ResolvedAt: now,
|
||||||
|
ExpiresAt: now.Add(10 * time.Minute),
|
||||||
|
}
|
||||||
|
wl.resolvedIPs["10.0.0.50"] = &ResolvedEntry{
|
||||||
|
IP: "10.0.0.50",
|
||||||
|
Source: "_sip._udp.provider.com",
|
||||||
|
SourceType: "srv",
|
||||||
|
ResolvedAt: now,
|
||||||
|
ExpiresAt: now.Add(10 * time.Minute),
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
ip string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"192.168.1.100", true},
|
||||||
|
{"10.0.0.50", true},
|
||||||
|
{"8.8.8.8", false},
|
||||||
|
{"192.168.1.101", false},
|
||||||
|
{"", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.ip, func(t *testing.T) {
|
||||||
|
if got := wl.Contains(tt.ip); got != tt.expected {
|
||||||
|
t.Errorf("Contains(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDNSWhitelistExpiredEntries tests handling of expired entries
|
||||||
|
func TestDNSWhitelistExpiredEntries(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
|
||||||
|
t.Run("expired entry rejected when AllowStale=false", func(t *testing.T) {
|
||||||
|
wl := NewDNSWhitelist(DNSWhitelistConfig{
|
||||||
|
AllowStale: false,
|
||||||
|
}, logger)
|
||||||
|
// Override the default
|
||||||
|
wl.config.AllowStale = false
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
|
||||||
|
IP: "192.168.1.100",
|
||||||
|
Source: "test.example.com",
|
||||||
|
SourceType: "hostname",
|
||||||
|
ResolvedAt: now.Add(-1 * time.Hour),
|
||||||
|
ExpiresAt: now.Add(-30 * time.Minute), // Already expired
|
||||||
|
}
|
||||||
|
|
||||||
|
if wl.Contains("192.168.1.100") {
|
||||||
|
t.Error("Expired entry should not match when AllowStale=false")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("expired entry allowed when AllowStale=true", func(t *testing.T) {
|
||||||
|
wl := NewDNSWhitelist(DNSWhitelistConfig{
|
||||||
|
AllowStale: true,
|
||||||
|
}, logger)
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
|
||||||
|
IP: "192.168.1.100",
|
||||||
|
Source: "test.example.com",
|
||||||
|
SourceType: "hostname",
|
||||||
|
ResolvedAt: now.Add(-1 * time.Hour),
|
||||||
|
ExpiresAt: now.Add(-30 * time.Minute), // Already expired
|
||||||
|
}
|
||||||
|
|
||||||
|
if !wl.Contains("192.168.1.100") {
|
||||||
|
t.Error("Expired entry should match when AllowStale=true")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDNSWhitelistGetSource tests source info retrieval
|
||||||
|
func TestDNSWhitelistGetSource(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger)
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
|
||||||
|
IP: "192.168.1.100",
|
||||||
|
Source: "pbx.example.com",
|
||||||
|
SourceType: "hostname",
|
||||||
|
ResolvedAt: now,
|
||||||
|
ExpiresAt: now.Add(10 * time.Minute),
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("existing IP returns source", func(t *testing.T) {
|
||||||
|
source := wl.GetSource("192.168.1.100")
|
||||||
|
if source == nil {
|
||||||
|
t.Fatal("GetSource returned nil for existing IP")
|
||||||
|
}
|
||||||
|
if source.Source != "pbx.example.com" {
|
||||||
|
t.Errorf("Source = %s, want pbx.example.com", source.Source)
|
||||||
|
}
|
||||||
|
if source.SourceType != "hostname" {
|
||||||
|
t.Errorf("SourceType = %s, want hostname", source.SourceType)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non-existent IP returns nil", func(t *testing.T) {
|
||||||
|
source := wl.GetSource("8.8.8.8")
|
||||||
|
if source != nil {
|
||||||
|
t.Error("GetSource should return nil for non-existent IP")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDNSWhitelistGetResolvedIPs tests listing all entries
|
||||||
|
func TestDNSWhitelistGetResolvedIPs(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger)
|
||||||
|
|
||||||
|
// Empty whitelist
|
||||||
|
entries := wl.GetResolvedIPs()
|
||||||
|
if len(entries) != 0 {
|
||||||
|
t.Errorf("Empty whitelist should return 0 entries, got %d", len(entries))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add entries
|
||||||
|
now := time.Now()
|
||||||
|
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
|
||||||
|
IP: "192.168.1.100",
|
||||||
|
Source: "host1.example.com",
|
||||||
|
SourceType: "hostname",
|
||||||
|
ResolvedAt: now,
|
||||||
|
ExpiresAt: now.Add(10 * time.Minute),
|
||||||
|
}
|
||||||
|
wl.resolvedIPs["10.0.0.50"] = &ResolvedEntry{
|
||||||
|
IP: "10.0.0.50",
|
||||||
|
Source: "_sip._udp.provider.com",
|
||||||
|
SourceType: "srv",
|
||||||
|
ResolvedAt: now,
|
||||||
|
ExpiresAt: now.Add(10 * time.Minute),
|
||||||
|
}
|
||||||
|
|
||||||
|
entries = wl.GetResolvedIPs()
|
||||||
|
if len(entries) != 2 {
|
||||||
|
t.Errorf("Expected 2 entries, got %d", len(entries))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDNSWhitelistStats tests statistics reporting
|
||||||
|
func TestDNSWhitelistStats(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
wl := NewDNSWhitelist(DNSWhitelistConfig{
|
||||||
|
Hostnames: []string{"host1.example.com", "host2.example.com"},
|
||||||
|
SRVRecords: []string{"_sip._udp.provider.com"},
|
||||||
|
RefreshInterval: 10 * time.Minute,
|
||||||
|
}, logger)
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
|
||||||
|
IP: "192.168.1.100",
|
||||||
|
Source: "host1.example.com",
|
||||||
|
SourceType: "hostname",
|
||||||
|
ResolvedAt: now,
|
||||||
|
ExpiresAt: now.Add(10 * time.Minute),
|
||||||
|
}
|
||||||
|
wl.resolvedIPs["192.168.1.101"] = &ResolvedEntry{
|
||||||
|
IP: "192.168.1.101",
|
||||||
|
Source: "host2.example.com",
|
||||||
|
SourceType: "hostname",
|
||||||
|
ResolvedAt: now,
|
||||||
|
ExpiresAt: now.Add(10 * time.Minute),
|
||||||
|
}
|
||||||
|
wl.resolvedIPs["10.0.0.50"] = &ResolvedEntry{
|
||||||
|
IP: "10.0.0.50",
|
||||||
|
Source: "_sip._udp.provider.com",
|
||||||
|
SourceType: "srv",
|
||||||
|
ResolvedAt: now,
|
||||||
|
ExpiresAt: now.Add(10 * time.Minute),
|
||||||
|
}
|
||||||
|
|
||||||
|
stats := wl.Stats()
|
||||||
|
|
||||||
|
if stats["total_ips"] != 3 {
|
||||||
|
t.Errorf("total_ips = %v, want 3", stats["total_ips"])
|
||||||
|
}
|
||||||
|
if stats["hostname_ips"] != 2 {
|
||||||
|
t.Errorf("hostname_ips = %v, want 2", stats["hostname_ips"])
|
||||||
|
}
|
||||||
|
if stats["srv_ips"] != 1 {
|
||||||
|
t.Errorf("srv_ips = %v, want 1", stats["srv_ips"])
|
||||||
|
}
|
||||||
|
if stats["configured_hosts"] != 2 {
|
||||||
|
t.Errorf("configured_hosts = %v, want 2", stats["configured_hosts"])
|
||||||
|
}
|
||||||
|
if stats["configured_srv"] != 1 {
|
||||||
|
t.Errorf("configured_srv = %v, want 1", stats["configured_srv"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDNSWhitelistResolveHostname tests hostname resolution with direct IPs
|
||||||
|
func TestDNSWhitelistResolveHostnameDirectIP(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger)
|
||||||
|
|
||||||
|
// Test that direct IP addresses are handled correctly
|
||||||
|
tests := []struct {
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"192.168.1.100", "192.168.1.100"},
|
||||||
|
{"10.0.0.1", "10.0.0.1"},
|
||||||
|
{"::1", "::1"},
|
||||||
|
{"2001:db8::1", "2001:db8::1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.input, func(t *testing.T) {
|
||||||
|
ips, err := wl.resolveHostname(nil, tt.input)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("resolveHostname(%s) error: %v", tt.input, err)
|
||||||
|
}
|
||||||
|
if len(ips) != 1 || ips[0] != tt.expected {
|
||||||
|
t.Errorf("resolveHostname(%s) = %v, want [%s]", tt.input, ips, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDNSWhitelistRealDNS tests actual DNS resolution (integration test)
|
||||||
|
// This test requires network connectivity
|
||||||
|
func TestDNSWhitelistRealDNS(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping DNS integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
logger := zap.NewNop()
|
||||||
|
wl := NewDNSWhitelist(DNSWhitelistConfig{
|
||||||
|
Hostnames: []string{"localhost"},
|
||||||
|
RefreshInterval: 1 * time.Minute,
|
||||||
|
ResolveTimeout: 5 * time.Second,
|
||||||
|
}, logger)
|
||||||
|
|
||||||
|
// Start and let it do initial resolution
|
||||||
|
if err := wl.Start(); err != nil {
|
||||||
|
t.Fatalf("Failed to start DNS whitelist: %v", err)
|
||||||
|
}
|
||||||
|
defer wl.Stop()
|
||||||
|
|
||||||
|
// Give it a moment to resolve
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
// localhost should resolve to 127.0.0.1 or ::1
|
||||||
|
if !wl.Contains("127.0.0.1") && !wl.Contains("::1") {
|
||||||
|
entries := wl.GetResolvedIPs()
|
||||||
|
t.Errorf("Expected localhost to resolve, got entries: %+v", entries)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDNSWhitelistStartStop tests the lifecycle management
|
||||||
|
func TestDNSWhitelistStartStop(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
wl := NewDNSWhitelist(DNSWhitelistConfig{
|
||||||
|
Hostnames: []string{"127.0.0.1"}, // Use IP to avoid DNS
|
||||||
|
RefreshInterval: 100 * time.Millisecond,
|
||||||
|
}, logger)
|
||||||
|
|
||||||
|
// Start
|
||||||
|
if err := wl.Start(); err != nil {
|
||||||
|
t.Fatalf("Start() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have the IP immediately
|
||||||
|
if !wl.Contains("127.0.0.1") {
|
||||||
|
t.Error("Should contain 127.0.0.1 after start")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop should complete without hanging
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
wl.Stop()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
// Good, stopped cleanly
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Error("Stop() took too long")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDNSWhitelistForceRefresh tests the manual refresh functionality
|
||||||
|
func TestDNSWhitelistForceRefresh(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
wl := NewDNSWhitelist(DNSWhitelistConfig{
|
||||||
|
Hostnames: []string{"127.0.0.1"},
|
||||||
|
RefreshInterval: 1 * time.Hour, // Long interval
|
||||||
|
}, logger)
|
||||||
|
|
||||||
|
// Don't start the refresh loop, just manually refresh
|
||||||
|
wl.ForceRefresh()
|
||||||
|
|
||||||
|
if !wl.Contains("127.0.0.1") {
|
||||||
|
t.Error("ForceRefresh should resolve IPs")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDNSWhitelistConcurrency tests concurrent access to the whitelist
|
||||||
|
func TestDNSWhitelistConcurrency(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
wl := NewDNSWhitelist(DNSWhitelistConfig{
|
||||||
|
Hostnames: []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"},
|
||||||
|
RefreshInterval: 50 * time.Millisecond,
|
||||||
|
}, logger)
|
||||||
|
|
||||||
|
if err := wl.Start(); err != nil {
|
||||||
|
t.Fatalf("Start() error: %v", err)
|
||||||
|
}
|
||||||
|
defer wl.Stop()
|
||||||
|
|
||||||
|
// Run concurrent reads while refresh is happening
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
wl.Contains("127.0.0.1")
|
||||||
|
wl.GetResolvedIPs()
|
||||||
|
wl.Stats()
|
||||||
|
}
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
// Good
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Error("Concurrent operations took too long")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSIPGuardianDNSWhitelistIntegration tests integration with SIPGuardian
|
||||||
|
func TestSIPGuardianDNSWhitelistIntegration(t *testing.T) {
|
||||||
|
// Create a SIPGuardian with DNS whitelist config
|
||||||
|
g := &SIPGuardian{
|
||||||
|
WhitelistHosts: []string{"127.0.0.1"},
|
||||||
|
WhitelistSRV: []string{},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize maps
|
||||||
|
g.bannedIPs = make(map[string]*BanEntry)
|
||||||
|
g.failureCounts = make(map[string]*failureTracker)
|
||||||
|
g.logger = zap.NewNop()
|
||||||
|
|
||||||
|
// Manually create DNS whitelist (normally done in Provision)
|
||||||
|
g.dnsWhitelist = NewDNSWhitelist(DNSWhitelistConfig{
|
||||||
|
Hostnames: g.WhitelistHosts,
|
||||||
|
}, g.logger)
|
||||||
|
g.dnsWhitelist.ForceRefresh()
|
||||||
|
|
||||||
|
// Test IsWhitelisted
|
||||||
|
if !g.IsWhitelisted("127.0.0.1") {
|
||||||
|
t.Error("127.0.0.1 should be whitelisted via DNS")
|
||||||
|
}
|
||||||
|
|
||||||
|
if g.IsWhitelisted("8.8.8.8") {
|
||||||
|
t.Error("8.8.8.8 should NOT be whitelisted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSIPGuardianMixedWhitelist tests CIDR + DNS whitelist together
|
||||||
|
func TestSIPGuardianMixedWhitelist(t *testing.T) {
|
||||||
|
g := &SIPGuardian{
|
||||||
|
WhitelistCIDR: []string{"10.0.0.0/8"},
|
||||||
|
WhitelistHosts: []string{"127.0.0.1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize
|
||||||
|
g.bannedIPs = make(map[string]*BanEntry)
|
||||||
|
g.failureCounts = make(map[string]*failureTracker)
|
||||||
|
g.logger = zap.NewNop()
|
||||||
|
|
||||||
|
// Parse CIDR whitelist (normally done in Provision)
|
||||||
|
for _, cidr := range g.WhitelistCIDR {
|
||||||
|
_, network, err := net.ParseCIDR(cidr)
|
||||||
|
if err == nil {
|
||||||
|
g.whitelistNets = append(g.whitelistNets, network)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up DNS whitelist
|
||||||
|
g.dnsWhitelist = NewDNSWhitelist(DNSWhitelistConfig{
|
||||||
|
Hostnames: g.WhitelistHosts,
|
||||||
|
}, g.logger)
|
||||||
|
g.dnsWhitelist.ForceRefresh()
|
||||||
|
|
||||||
|
// Test CIDR whitelist
|
||||||
|
if !g.IsWhitelisted("10.1.2.3") {
|
||||||
|
t.Error("10.1.2.3 should be whitelisted via CIDR")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test DNS whitelist
|
||||||
|
if !g.IsWhitelisted("127.0.0.1") {
|
||||||
|
t.Error("127.0.0.1 should be whitelisted via DNS")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test non-whitelisted
|
||||||
|
if g.IsWhitelisted("8.8.8.8") {
|
||||||
|
t.Error("8.8.8.8 should NOT be whitelisted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkDNSWhitelistContains benchmarks lookup performance
|
||||||
|
func BenchmarkDNSWhitelistContains(b *testing.B) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger)
|
||||||
|
|
||||||
|
// Add 1000 entries
|
||||||
|
now := time.Now()
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
ip := "192.168." + string(rune('0'+i/256)) + "." + string(rune('0'+i%256))
|
||||||
|
wl.resolvedIPs[ip] = &ResolvedEntry{
|
||||||
|
IP: ip,
|
||||||
|
Source: "test.example.com",
|
||||||
|
ExpiresAt: now.Add(10 * time.Minute),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
wl.Contains("192.168.0.100")
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -322,6 +322,7 @@ func (h *SIPHandler) Handle(cx *layer4.Connection, next layer4.Handler) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// suspiciousPatternDefs defines patterns and their names for detection
|
// suspiciousPatternDefs defines patterns and their names for detection
|
||||||
|
// IMPORTANT: Patterns must be specific enough to avoid false positives on legitimate traffic
|
||||||
var suspiciousPatternDefs = []struct {
|
var suspiciousPatternDefs = []struct {
|
||||||
name string
|
name string
|
||||||
pattern string
|
pattern string
|
||||||
@ -331,12 +332,14 @@ var suspiciousPatternDefs = []struct {
|
|||||||
{"sipcli", "sipcli"},
|
{"sipcli", "sipcli"},
|
||||||
{"sip-scan", "sip-scan"},
|
{"sip-scan", "sip-scan"},
|
||||||
{"voipbuster", "voipbuster"},
|
{"voipbuster", "voipbuster"},
|
||||||
{"asterisk-pbx-scanner", "asterisk pbx"},
|
// Note: "asterisk pbx scanner" pattern removed - too broad, catches legitimate Asterisk PBX systems
|
||||||
|
// The original pattern "asterisk pbx" would match "User-Agent: Asterisk PBX 18.0" which is legitimate
|
||||||
{"sipsak", "sipsak"},
|
{"sipsak", "sipsak"},
|
||||||
{"sundayddr", "sundayddr"},
|
{"sundayddr", "sundayddr"},
|
||||||
{"iwar", "iwar"},
|
{"iwar", "iwar"},
|
||||||
{"cseq-flood", "cseq: 1 options"}, // Repeated OPTIONS flood
|
// Note: "cseq: 1 options" pattern REMOVED - too broad, catches ANY first OPTIONS request
|
||||||
{"zoiper-spoof", "user-agent: zoiper"},
|
// OPTIONS with CSeq 1 is completely normal - it's the first OPTIONS from any client
|
||||||
|
// Use rate limiting for OPTIONS flood detection instead
|
||||||
{"test-extension-100", "sip:100@"},
|
{"test-extension-100", "sip:100@"},
|
||||||
{"test-extension-1000", "sip:1000@"},
|
{"test-extension-1000", "sip:1000@"},
|
||||||
{"null-user", "sip:@"},
|
{"null-user", "sip:@"},
|
||||||
|
|||||||
596
l4handler_test.go
Normal file
596
l4handler_test.go
Normal file
@ -0,0 +1,596 @@
|
|||||||
|
package sipguardian
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// SIP Matcher Tests - Verifying SIP traffic is correctly identified
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
// provisionMatcherForTest creates a SIPMatcher with default methods without requiring Caddy context
|
||||||
|
func provisionMatcherForTest(methods []string) *SIPMatcher {
|
||||||
|
if len(methods) == 0 {
|
||||||
|
methods = []string{"REGISTER", "INVITE", "OPTIONS", "ACK", "BYE", "CANCEL", "INFO", "NOTIFY", "SUBSCRIBE", "MESSAGE"}
|
||||||
|
}
|
||||||
|
pattern := "^(" + strings.Join(methods, "|") + ") sip:"
|
||||||
|
return &SIPMatcher{
|
||||||
|
Methods: methods,
|
||||||
|
methodRegex: regexp.MustCompile("(?i)" + pattern),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSIPMethodPatternMatching(t *testing.T) {
|
||||||
|
// Create a provisioned matcher using our test helper
|
||||||
|
m := provisionMatcherForTest(nil)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
data []byte
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
// Legitimate SIP requests - MUST match
|
||||||
|
{
|
||||||
|
name: "REGISTER request",
|
||||||
|
data: []byte("REGISTER sip:example.com SIP/2.0\r\n"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "INVITE request",
|
||||||
|
data: []byte("INVITE sip:alice@example.com SIP/2.0\r\n"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "OPTIONS request",
|
||||||
|
data: []byte("OPTIONS sip:example.com SIP/2.0\r\n"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ACK request",
|
||||||
|
data: []byte("ACK sip:alice@example.com SIP/2.0\r\n"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "BYE request",
|
||||||
|
data: []byte("BYE sip:alice@192.168.1.100 SIP/2.0\r\n"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "CANCEL request",
|
||||||
|
data: []byte("CANCEL sip:bob@pbx.local SIP/2.0\r\n"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "INFO request",
|
||||||
|
data: []byte("INFO sip:alice@example.com SIP/2.0\r\n"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "NOTIFY request",
|
||||||
|
data: []byte("NOTIFY sip:alice@example.com SIP/2.0\r\n"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SUBSCRIBE request",
|
||||||
|
data: []byte("SUBSCRIBE sip:alice@example.com SIP/2.0\r\n"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "MESSAGE request",
|
||||||
|
data: []byte("MESSAGE sip:alice@example.com SIP/2.0\r\n"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
// Case insensitivity
|
||||||
|
{
|
||||||
|
name: "lowercase register",
|
||||||
|
data: []byte("register sip:example.com SIP/2.0\r\n"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed case INVITE",
|
||||||
|
data: []byte("Invite sip:alice@example.com SIP/2.0\r\n"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
// SIP responses
|
||||||
|
{
|
||||||
|
name: "SIP 200 OK response",
|
||||||
|
data: []byte("SIP/2.0 200 OK\r\n"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SIP 100 Trying response",
|
||||||
|
data: []byte("SIP/2.0 100 Trying\r\n"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SIP 180 Ringing response",
|
||||||
|
data: []byte("SIP/2.0 180 Ringing\r\n"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SIP 401 Unauthorized response",
|
||||||
|
data: []byte("SIP/2.0 401 Unauthorized\r\n"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SIP 486 Busy Here response",
|
||||||
|
data: []byte("SIP/2.0 486 Busy Here\r\n"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
// Non-SIP traffic - MUST NOT match (should be passed through or rejected elsewhere)
|
||||||
|
{
|
||||||
|
name: "HTTP GET request",
|
||||||
|
data: []byte("GET / HTTP/1.1\r\n"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "HTTP POST request",
|
||||||
|
data: []byte("POST /api HTTP/1.1\r\n"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SMTP EHLO",
|
||||||
|
data: []byte("EHLO mail.example.com\r\n"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "random binary data",
|
||||||
|
data: []byte{0x00, 0x01, 0x02, 0x03, 0x04},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RTP-like packet",
|
||||||
|
data: []byte{0x80, 0x00, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
matches := m.methodRegex.Match(tt.data) || bytes.HasPrefix(tt.data, []byte("SIP/2.0"))
|
||||||
|
if matches != tt.expected {
|
||||||
|
t.Errorf("SIP pattern match for %q: got %v, want %v", tt.name, matches, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMatcherDefaultMethods verifies the matcher provisions with correct default methods
|
||||||
|
func TestMatcherDefaultMethods(t *testing.T) {
|
||||||
|
m := provisionMatcherForTest(nil)
|
||||||
|
|
||||||
|
expectedMethods := []string{"REGISTER", "INVITE", "OPTIONS", "ACK", "BYE", "CANCEL", "INFO", "NOTIFY", "SUBSCRIBE", "MESSAGE"}
|
||||||
|
|
||||||
|
if len(m.Methods) != len(expectedMethods) {
|
||||||
|
t.Errorf("Default methods count: got %d, want %d", len(m.Methods), len(expectedMethods))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, method := range expectedMethods {
|
||||||
|
found := false
|
||||||
|
for _, m := range m.Methods {
|
||||||
|
if m == method {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
t.Errorf("Expected method %s not in default methods", method)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMatcherCustomMethods verifies custom method configuration works
|
||||||
|
func TestMatcherCustomMethods(t *testing.T) {
|
||||||
|
m := provisionMatcherForTest([]string{"REGISTER", "INVITE"})
|
||||||
|
|
||||||
|
// Should match REGISTER
|
||||||
|
if !m.methodRegex.Match([]byte("REGISTER sip:example.com SIP/2.0\r\n")) {
|
||||||
|
t.Error("Should match REGISTER when configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should match INVITE
|
||||||
|
if !m.methodRegex.Match([]byte("INVITE sip:alice@example.com SIP/2.0\r\n")) {
|
||||||
|
t.Error("Should match INVITE when configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should NOT match OPTIONS (not in our custom list)
|
||||||
|
if m.methodRegex.Match([]byte("OPTIONS sip:example.com SIP/2.0\r\n")) {
|
||||||
|
t.Error("Should NOT match OPTIONS when not in custom methods")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Suspicious Pattern Detection Tests
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
func TestDetectSuspiciousPattern(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
data []byte
|
||||||
|
expectDetection bool
|
||||||
|
expectedPattern string
|
||||||
|
}{
|
||||||
|
// Known attack tools - MUST be detected
|
||||||
|
{
|
||||||
|
name: "SIPVicious User-Agent",
|
||||||
|
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: friendly-scanner\r\n"),
|
||||||
|
expectDetection: true,
|
||||||
|
expectedPattern: "friendly-scanner",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "SIPVicious lowercase",
|
||||||
|
data: []byte("OPTIONS sip:example.com SIP/2.0\r\nUser-Agent: sipvicious\r\n"),
|
||||||
|
expectDetection: true,
|
||||||
|
expectedPattern: "sipvicious",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sipcli scanner",
|
||||||
|
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: sipcli/1.0\r\n"),
|
||||||
|
expectDetection: true,
|
||||||
|
expectedPattern: "sipcli",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sipsak tool",
|
||||||
|
data: []byte("OPTIONS sip:example.com SIP/2.0\r\nUser-Agent: sipsak 0.9.7\r\n"),
|
||||||
|
expectDetection: true,
|
||||||
|
expectedPattern: "sipsak",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "VoIPBuster",
|
||||||
|
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: voipbuster\r\n"),
|
||||||
|
expectDetection: true,
|
||||||
|
expectedPattern: "voipbuster",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "sundayddr scanner",
|
||||||
|
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: sundayddr\r\n"),
|
||||||
|
expectDetection: true,
|
||||||
|
expectedPattern: "sundayddr",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "iwar dialer",
|
||||||
|
data: []byte("INVITE sip:alice@example.com SIP/2.0\r\nUser-Agent: iwar/0.1\r\n"),
|
||||||
|
expectDetection: true,
|
||||||
|
expectedPattern: "iwar",
|
||||||
|
},
|
||||||
|
// Common enumeration patterns
|
||||||
|
{
|
||||||
|
name: "test extension 100",
|
||||||
|
data: []byte("REGISTER sip:100@example.com SIP/2.0\r\n"),
|
||||||
|
expectDetection: true,
|
||||||
|
expectedPattern: "test-extension-100",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "test extension 1000",
|
||||||
|
data: []byte("REGISTER sip:1000@example.com SIP/2.0\r\n"),
|
||||||
|
expectDetection: true,
|
||||||
|
expectedPattern: "test-extension-1000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "null user probe",
|
||||||
|
data: []byte("REGISTER sip:@example.com SIP/2.0\r\n"),
|
||||||
|
expectDetection: true,
|
||||||
|
expectedPattern: "null-user",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "anonymous caller",
|
||||||
|
data: []byte("INVITE sip:bob@example.com SIP/2.0\r\nFrom: <sip:anonymous@anonymous.invalid>\r\n"),
|
||||||
|
expectDetection: true,
|
||||||
|
expectedPattern: "anonymous",
|
||||||
|
},
|
||||||
|
// LEGITIMATE traffic - MUST NOT be detected as suspicious
|
||||||
|
{
|
||||||
|
name: "Zoiper softphone",
|
||||||
|
// Zoiper is a legitimate softphone - pattern removed to avoid false positives
|
||||||
|
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: Zoiper rv2.0.18\r\n"),
|
||||||
|
expectDetection: false,
|
||||||
|
expectedPattern: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Linphone client",
|
||||||
|
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: Linphone/4.5.0\r\n"),
|
||||||
|
expectDetection: false,
|
||||||
|
expectedPattern: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Asterisk PBX",
|
||||||
|
// The "asterisk pbx" pattern was removed as it caused false positives
|
||||||
|
// Legitimate Asterisk servers now pass through correctly
|
||||||
|
data: []byte("INVITE sip:alice@example.com SIP/2.0\r\nUser-Agent: Asterisk PBX 18.0\r\n"),
|
||||||
|
expectDetection: false,
|
||||||
|
expectedPattern: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "FreeSWITCH",
|
||||||
|
data: []byte("INVITE sip:bob@example.com SIP/2.0\r\nUser-Agent: FreeSWITCH\r\n"),
|
||||||
|
expectDetection: false,
|
||||||
|
expectedPattern: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Grandstream phone",
|
||||||
|
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: Grandstream GXP2170 1.0.11.10\r\n"),
|
||||||
|
expectDetection: false,
|
||||||
|
expectedPattern: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Polycom phone",
|
||||||
|
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: PolycomVVX-VVX_410-UA/5.9.3\r\n"),
|
||||||
|
expectDetection: false,
|
||||||
|
expectedPattern: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Yealink phone",
|
||||||
|
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: Yealink SIP-T46S 66.86.0.15\r\n"),
|
||||||
|
expectDetection: false,
|
||||||
|
expectedPattern: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Cisco phone",
|
||||||
|
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: Cisco-SIPIPCommunicator/9.1.1\r\n"),
|
||||||
|
expectDetection: false,
|
||||||
|
expectedPattern: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Avaya phone",
|
||||||
|
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: Avaya one-X Communicator\r\n"),
|
||||||
|
expectDetection: false,
|
||||||
|
expectedPattern: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "3CX Softphone",
|
||||||
|
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: 3CXPhone 6.0\r\n"),
|
||||||
|
expectDetection: false,
|
||||||
|
expectedPattern: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Twilio gateway",
|
||||||
|
data: []byte("INVITE sip:+15551234567@example.com SIP/2.0\r\nUser-Agent: twilio-client/2.0\r\n"),
|
||||||
|
expectDetection: false,
|
||||||
|
expectedPattern: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "regular extension 5001",
|
||||||
|
data: []byte("REGISTER sip:5001@example.com SIP/2.0\r\nUser-Agent: Linphone\r\n"),
|
||||||
|
expectDetection: false,
|
||||||
|
expectedPattern: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "regular extension 1234",
|
||||||
|
data: []byte("INVITE sip:1234@pbx.local SIP/2.0\r\n"),
|
||||||
|
expectDetection: false,
|
||||||
|
expectedPattern: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
pattern := detectSuspiciousPattern(tt.data)
|
||||||
|
detected := pattern != ""
|
||||||
|
|
||||||
|
if detected != tt.expectDetection {
|
||||||
|
t.Errorf("Detection for %q: got detected=%v, want detected=%v (pattern=%q)",
|
||||||
|
tt.name, detected, tt.expectDetection, pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.expectDetection && pattern != tt.expectedPattern {
|
||||||
|
t.Errorf("Pattern for %q: got %q, want %q", tt.name, pattern, tt.expectedPattern)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLegacyIsSuspiciousSIP verifies the legacy wrapper function
|
||||||
|
func TestLegacyIsSuspiciousSIP(t *testing.T) {
|
||||||
|
// Suspicious - should return true
|
||||||
|
if !isSuspiciousSIP([]byte("User-Agent: friendly-scanner")) {
|
||||||
|
t.Error("Should detect friendly-scanner as suspicious")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not suspicious - should return false
|
||||||
|
if isSuspiciousSIP([]byte("User-Agent: Linphone/4.5.0")) {
|
||||||
|
t.Error("Should NOT detect Linphone as suspicious")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Complete SIP Message Tests - Real-world SIP traffic patterns
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
func TestLegitimateREGISTERMessage(t *testing.T) {
|
||||||
|
// A complete, legitimate REGISTER message from a typical SIP phone
|
||||||
|
// Note: Using proper CRLF line endings as per SIP RFC 3261
|
||||||
|
msg := []byte("REGISTER sip:1001@pbx.example.com SIP/2.0\r\n" +
|
||||||
|
"Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK-524287-1-0\r\n" +
|
||||||
|
"Max-Forwards: 70\r\n" +
|
||||||
|
"From: \"John Smith\" <sip:1001@pbx.example.com>;tag=1\r\n" +
|
||||||
|
"To: <sip:1001@pbx.example.com>\r\n" +
|
||||||
|
"Call-ID: 1-1234@192.168.1.100\r\n" +
|
||||||
|
"CSeq: 1 REGISTER\r\n" +
|
||||||
|
"Contact: <sip:1001@192.168.1.100:5060>\r\n" +
|
||||||
|
"User-Agent: Yealink SIP-T46S 66.86.0.15\r\n" +
|
||||||
|
"Expires: 3600\r\n" +
|
||||||
|
"Allow: INVITE, ACK, CANCEL, OPTIONS, BYE, REFER, NOTIFY, MESSAGE, SUBSCRIBE, INFO\r\n" +
|
||||||
|
"Content-Length: 0\r\n" +
|
||||||
|
"\r\n")
|
||||||
|
|
||||||
|
// Should NOT be detected as suspicious
|
||||||
|
pattern := detectSuspiciousPattern(msg)
|
||||||
|
if pattern != "" {
|
||||||
|
t.Errorf("Legitimate REGISTER should NOT be flagged as suspicious, got pattern: %s", pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SIP method extraction should work
|
||||||
|
method := ExtractSIPMethod(msg)
|
||||||
|
if method != MethodREGISTER {
|
||||||
|
t.Errorf("Method extraction: got %v, want REGISTER", method)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extension extraction should work
|
||||||
|
ext := ExtractTargetExtension(msg)
|
||||||
|
if ext != "1001" {
|
||||||
|
t.Errorf("Extension extraction: got %q, want 1001", ext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLegitimateINVITEMessage(t *testing.T) {
|
||||||
|
// A complete, legitimate INVITE message for a call
|
||||||
|
msg := []byte(`INVITE sip:5002@pbx.example.com SIP/2.0
|
||||||
|
Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK-1234567
|
||||||
|
Max-Forwards: 70
|
||||||
|
From: "Alice" <sip:1001@pbx.example.com>;tag=abc123
|
||||||
|
To: <sip:5002@pbx.example.com>
|
||||||
|
Call-ID: call-8888@192.168.1.100
|
||||||
|
CSeq: 1 INVITE
|
||||||
|
Contact: <sip:1001@192.168.1.100:5060>
|
||||||
|
User-Agent: Grandstream GXP2170 1.0.11.10
|
||||||
|
Allow: INVITE, ACK, CANCEL, OPTIONS, BYE, REFER, NOTIFY, MESSAGE, SUBSCRIBE, INFO
|
||||||
|
Content-Type: application/sdp
|
||||||
|
Content-Length: 260
|
||||||
|
|
||||||
|
v=0
|
||||||
|
o=- 1234567890 1234567890 IN IP4 192.168.1.100
|
||||||
|
s=-
|
||||||
|
c=IN IP4 192.168.1.100
|
||||||
|
t=0 0
|
||||||
|
m=audio 10000 RTP/AVP 0 8 101
|
||||||
|
a=rtpmap:0 PCMU/8000
|
||||||
|
a=rtpmap:8 PCMA/8000
|
||||||
|
a=rtpmap:101 telephone-event/8000
|
||||||
|
a=fmtp:101 0-16
|
||||||
|
`)
|
||||||
|
|
||||||
|
// Should NOT be detected as suspicious
|
||||||
|
pattern := detectSuspiciousPattern(msg)
|
||||||
|
if pattern != "" {
|
||||||
|
t.Errorf("Legitimate INVITE should NOT be flagged as suspicious, got pattern: %s", pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SIP method extraction should work
|
||||||
|
method := ExtractSIPMethod(msg)
|
||||||
|
if method != MethodINVITE {
|
||||||
|
t.Errorf("Method extraction: got %v, want INVITE", method)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extension extraction should work
|
||||||
|
ext := ExtractTargetExtension(msg)
|
||||||
|
if ext != "5002" {
|
||||||
|
t.Errorf("Extension extraction: got %q, want 5002", ext)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLegitimateOPTIONSKeepAlive(t *testing.T) {
|
||||||
|
// OPTIONS is commonly used for NAT keep-alive
|
||||||
|
msg := []byte(`OPTIONS sip:pbx.example.com SIP/2.0
|
||||||
|
Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK-ping-001
|
||||||
|
Max-Forwards: 70
|
||||||
|
From: <sip:1001@pbx.example.com>;tag=keepalive
|
||||||
|
To: <sip:pbx.example.com>
|
||||||
|
Call-ID: keepalive-12345@192.168.1.100
|
||||||
|
CSeq: 100 OPTIONS
|
||||||
|
User-Agent: Polycom/5.9.3
|
||||||
|
Accept: application/sdp
|
||||||
|
Content-Length: 0
|
||||||
|
|
||||||
|
`)
|
||||||
|
|
||||||
|
// Should NOT be detected as suspicious
|
||||||
|
pattern := detectSuspiciousPattern(msg)
|
||||||
|
if pattern != "" {
|
||||||
|
t.Errorf("Legitimate OPTIONS keep-alive should NOT be flagged, got pattern: %s", pattern)
|
||||||
|
}
|
||||||
|
|
||||||
|
method := ExtractSIPMethod(msg)
|
||||||
|
if method != MethodOPTIONS {
|
||||||
|
t.Errorf("Method extraction: got %v, want OPTIONS", method)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLegitimate200OKResponse(t *testing.T) {
|
||||||
|
// A 200 OK response to REGISTER
|
||||||
|
msg := []byte(`SIP/2.0 200 OK
|
||||||
|
Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK-524287-1-0;received=192.168.1.100
|
||||||
|
From: "John Smith" <sip:1001@pbx.example.com>;tag=1
|
||||||
|
To: <sip:1001@pbx.example.com>;tag=as1234
|
||||||
|
Call-ID: 1-1234@192.168.1.100
|
||||||
|
CSeq: 1 REGISTER
|
||||||
|
Contact: <sip:1001@192.168.1.100:5060>;expires=3600
|
||||||
|
Date: Mon, 01 Jan 2024 12:00:00 GMT
|
||||||
|
Server: Asterisk PBX 18.0
|
||||||
|
Content-Length: 0
|
||||||
|
|
||||||
|
`)
|
||||||
|
|
||||||
|
// Should NOT be detected as suspicious (Server: Asterisk is NOT "asterisk pbx" scanner signature)
|
||||||
|
pattern := detectSuspiciousPattern(msg)
|
||||||
|
if pattern != "" && pattern != "asterisk-pbx-scanner" {
|
||||||
|
t.Errorf("200 OK response should NOT be flagged as suspicious, got pattern: %s", pattern)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Helper Function - min()
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
func TestMinFunction(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
a, b, expected int
|
||||||
|
}{
|
||||||
|
{1, 2, 1},
|
||||||
|
{2, 1, 1},
|
||||||
|
{5, 5, 5},
|
||||||
|
{0, 10, 0},
|
||||||
|
{-1, 1, -1},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
result := min(tt.a, tt.b)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("min(%d, %d) = %d, want %d", tt.a, tt.b, result, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// SIPHandler Module Info Test
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
func TestSIPHandlerModuleInfo(t *testing.T) {
|
||||||
|
h := SIPHandler{}
|
||||||
|
info := h.CaddyModule()
|
||||||
|
|
||||||
|
if info.ID != "layer4.handlers.sip_guardian" {
|
||||||
|
t.Errorf("Module ID: got %q, want %q", info.ID, "layer4.handlers.sip_guardian")
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.New == nil {
|
||||||
|
t.Error("Module New function should not be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify New() returns correct type
|
||||||
|
newModule := info.New()
|
||||||
|
if _, ok := newModule.(*SIPHandler); !ok {
|
||||||
|
t.Error("New() should return *SIPHandler")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSIPMatcherModuleInfo(t *testing.T) {
|
||||||
|
m := SIPMatcher{}
|
||||||
|
info := m.CaddyModule()
|
||||||
|
|
||||||
|
if info.ID != "layer4.matchers.sip" {
|
||||||
|
t.Errorf("Module ID: got %q, want %q", info.ID, "layer4.matchers.sip")
|
||||||
|
}
|
||||||
|
|
||||||
|
if info.New == nil {
|
||||||
|
t.Error("Module New function should not be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify New() returns correct type
|
||||||
|
newModule := info.New()
|
||||||
|
if _, ok := newModule.(*SIPMatcher); !ok {
|
||||||
|
t.Error("New() should return *SIPMatcher")
|
||||||
|
}
|
||||||
|
}
|
||||||
698
ratelimit_test.go
Normal file
698
ratelimit_test.go
Normal file
@ -0,0 +1,698 @@
|
|||||||
|
package sipguardian
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Rate Limiter Tests - Ensuring legitimate traffic flows through
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
func TestRateLimiterBasicAllow(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
rl := NewRateLimiter(logger)
|
||||||
|
|
||||||
|
// Configure a simple limit
|
||||||
|
rl.SetLimit(MethodREGISTER, &MethodRateLimit{
|
||||||
|
Method: MethodREGISTER,
|
||||||
|
MaxRequests: 10,
|
||||||
|
Window: time.Minute,
|
||||||
|
BurstSize: 5,
|
||||||
|
})
|
||||||
|
|
||||||
|
ip := "192.168.1.100"
|
||||||
|
|
||||||
|
// First 5 requests should be allowed (burst)
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
allowed, reason := rl.Allow(ip, MethodREGISTER)
|
||||||
|
if !allowed {
|
||||||
|
t.Errorf("Request %d should be allowed within burst, reason: %s", i+1, reason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterBurstExhaustion(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
rl := NewRateLimiter(logger)
|
||||||
|
|
||||||
|
rl.SetLimit(MethodREGISTER, &MethodRateLimit{
|
||||||
|
Method: MethodREGISTER,
|
||||||
|
MaxRequests: 10,
|
||||||
|
Window: time.Minute,
|
||||||
|
BurstSize: 3,
|
||||||
|
})
|
||||||
|
|
||||||
|
ip := "192.168.1.101"
|
||||||
|
|
||||||
|
// Exhaust burst
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
allowed, _ := rl.Allow(ip, MethodREGISTER)
|
||||||
|
if !allowed {
|
||||||
|
t.Errorf("Request %d should be allowed within burst", i+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Next request should be rate limited (burst exhausted)
|
||||||
|
allowed, reason := rl.Allow(ip, MethodREGISTER)
|
||||||
|
if allowed {
|
||||||
|
t.Error("Request after burst should be rate limited")
|
||||||
|
}
|
||||||
|
if reason != "rate_limit_REGISTER" {
|
||||||
|
t.Errorf("Reason should be rate_limit_REGISTER, got: %s", reason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterTokenRefill(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
rl := NewRateLimiter(logger)
|
||||||
|
|
||||||
|
// 60 requests per minute = 1 per second
|
||||||
|
rl.SetLimit(MethodOPTIONS, &MethodRateLimit{
|
||||||
|
Method: MethodOPTIONS,
|
||||||
|
MaxRequests: 60,
|
||||||
|
Window: time.Minute,
|
||||||
|
BurstSize: 2,
|
||||||
|
})
|
||||||
|
|
||||||
|
ip := "192.168.1.102"
|
||||||
|
|
||||||
|
// Exhaust burst
|
||||||
|
rl.Allow(ip, MethodOPTIONS)
|
||||||
|
rl.Allow(ip, MethodOPTIONS)
|
||||||
|
|
||||||
|
// Should be blocked now
|
||||||
|
allowed, _ := rl.Allow(ip, MethodOPTIONS)
|
||||||
|
if allowed {
|
||||||
|
t.Error("Should be rate limited after burst")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for token refill (slightly more than 1 second for 1 token)
|
||||||
|
time.Sleep(1100 * time.Millisecond)
|
||||||
|
|
||||||
|
// Should be allowed again
|
||||||
|
allowed, reason := rl.Allow(ip, MethodOPTIONS)
|
||||||
|
if !allowed {
|
||||||
|
t.Errorf("Should be allowed after token refill, reason: %s", reason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterDifferentMethods(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
rl := NewRateLimiter(logger)
|
||||||
|
|
||||||
|
// Different limits for different methods
|
||||||
|
rl.SetLimit(MethodREGISTER, &MethodRateLimit{
|
||||||
|
Method: MethodREGISTER,
|
||||||
|
MaxRequests: 10,
|
||||||
|
Window: time.Minute,
|
||||||
|
BurstSize: 2,
|
||||||
|
})
|
||||||
|
rl.SetLimit(MethodINVITE, &MethodRateLimit{
|
||||||
|
Method: MethodINVITE,
|
||||||
|
MaxRequests: 30,
|
||||||
|
Window: time.Minute,
|
||||||
|
BurstSize: 5,
|
||||||
|
})
|
||||||
|
|
||||||
|
ip := "192.168.1.103"
|
||||||
|
|
||||||
|
// Exhaust REGISTER burst
|
||||||
|
rl.Allow(ip, MethodREGISTER)
|
||||||
|
rl.Allow(ip, MethodREGISTER)
|
||||||
|
|
||||||
|
// REGISTER should be blocked
|
||||||
|
allowed, _ := rl.Allow(ip, MethodREGISTER)
|
||||||
|
if allowed {
|
||||||
|
t.Error("REGISTER should be rate limited")
|
||||||
|
}
|
||||||
|
|
||||||
|
// But INVITE should still work (separate bucket)
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
allowed, reason := rl.Allow(ip, MethodINVITE)
|
||||||
|
if !allowed {
|
||||||
|
t.Errorf("INVITE %d should be allowed, reason: %s", i+1, reason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterDifferentIPs(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
rl := NewRateLimiter(logger)
|
||||||
|
|
||||||
|
rl.SetLimit(MethodREGISTER, &MethodRateLimit{
|
||||||
|
Method: MethodREGISTER,
|
||||||
|
MaxRequests: 10,
|
||||||
|
Window: time.Minute,
|
||||||
|
BurstSize: 2,
|
||||||
|
})
|
||||||
|
|
||||||
|
ip1 := "192.168.1.104"
|
||||||
|
ip2 := "192.168.1.105"
|
||||||
|
|
||||||
|
// Exhaust IP1's burst
|
||||||
|
rl.Allow(ip1, MethodREGISTER)
|
||||||
|
rl.Allow(ip1, MethodREGISTER)
|
||||||
|
allowed, _ := rl.Allow(ip1, MethodREGISTER)
|
||||||
|
if allowed {
|
||||||
|
t.Error("IP1 should be rate limited")
|
||||||
|
}
|
||||||
|
|
||||||
|
// IP2 should still work (separate bucket)
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
allowed, reason := rl.Allow(ip2, MethodREGISTER)
|
||||||
|
if !allowed {
|
||||||
|
t.Errorf("IP2 request %d should be allowed, reason: %s", i+1, reason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterDefaultLimits(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
|
||||||
|
// Reset global rate limiter to get a fresh instance with defaults applied
|
||||||
|
rateLimiterMu.Lock()
|
||||||
|
globalRateLimiter = nil
|
||||||
|
rateLimiterMu.Unlock()
|
||||||
|
|
||||||
|
// GetRateLimiter applies default limits from DefaultMethodLimits
|
||||||
|
rl := GetRateLimiter(logger)
|
||||||
|
|
||||||
|
// Test that DefaultMethodLimits are applied
|
||||||
|
for method, expectedLimit := range DefaultMethodLimits {
|
||||||
|
limit := rl.GetLimit(method)
|
||||||
|
if limit.MaxRequests != expectedLimit.MaxRequests {
|
||||||
|
t.Errorf("Default limit for %s: got %d, want %d", method, limit.MaxRequests, expectedLimit.MaxRequests)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewRateLimiterHasNoDefaultLimits(t *testing.T) {
|
||||||
|
// NewRateLimiter creates a fresh limiter without default method limits
|
||||||
|
// This is intentional for testing and custom configurations
|
||||||
|
logger := zap.NewNop()
|
||||||
|
rl := NewRateLimiter(logger)
|
||||||
|
|
||||||
|
// Should return the global default (100), not method-specific defaults
|
||||||
|
limit := rl.GetLimit(MethodREGISTER)
|
||||||
|
if limit.MaxRequests != 100 {
|
||||||
|
t.Errorf("NewRateLimiter should have global default 100 for unconfigured methods, got %d", limit.MaxRequests)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterConcurrentAccess(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
rl := NewRateLimiter(logger)
|
||||||
|
|
||||||
|
rl.SetLimit(MethodREGISTER, &MethodRateLimit{
|
||||||
|
Method: MethodREGISTER,
|
||||||
|
MaxRequests: 100,
|
||||||
|
Window: time.Minute,
|
||||||
|
BurstSize: 50,
|
||||||
|
})
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
allowedCount := 0
|
||||||
|
var mu sync.Mutex
|
||||||
|
|
||||||
|
// Simulate 100 concurrent requests from different IPs
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(n int) {
|
||||||
|
defer wg.Done()
|
||||||
|
ip := "192.168.1." + string(rune('0'+n%10))
|
||||||
|
allowed, _ := rl.Allow(ip, MethodREGISTER)
|
||||||
|
if allowed {
|
||||||
|
mu.Lock()
|
||||||
|
allowedCount++
|
||||||
|
mu.Unlock()
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Should have allowed most requests (10 IPs with burst of 50 each = up to 500 capacity)
|
||||||
|
if allowedCount < 90 {
|
||||||
|
t.Errorf("Expected at least 90 allowed requests, got %d", allowedCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterCleanup(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
rl := NewRateLimiter(logger)
|
||||||
|
|
||||||
|
// Add some buckets
|
||||||
|
rl.Allow("192.168.1.1", MethodREGISTER)
|
||||||
|
rl.Allow("192.168.1.2", MethodREGISTER)
|
||||||
|
rl.Allow("192.168.1.3", MethodREGISTER)
|
||||||
|
|
||||||
|
stats := rl.GetStats()
|
||||||
|
trackedIPs := stats["tracked_ips"].(int)
|
||||||
|
if trackedIPs != 3 {
|
||||||
|
t.Errorf("Should have 3 tracked IPs, got %d", trackedIPs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup shouldn't remove fresh entries
|
||||||
|
rl.Cleanup()
|
||||||
|
stats = rl.GetStats()
|
||||||
|
trackedIPs = stats["tracked_ips"].(int)
|
||||||
|
if trackedIPs != 3 {
|
||||||
|
t.Errorf("Fresh entries should not be cleaned up, got %d IPs", trackedIPs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterStats(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
rl := NewRateLimiter(logger)
|
||||||
|
|
||||||
|
rl.SetLimit(MethodREGISTER, &MethodRateLimit{
|
||||||
|
Method: MethodREGISTER,
|
||||||
|
MaxRequests: 10,
|
||||||
|
Window: time.Minute,
|
||||||
|
BurstSize: 5,
|
||||||
|
})
|
||||||
|
|
||||||
|
rl.Allow("192.168.1.100", MethodREGISTER)
|
||||||
|
|
||||||
|
stats := rl.GetStats()
|
||||||
|
|
||||||
|
if _, ok := stats["tracked_ips"]; !ok {
|
||||||
|
t.Error("Stats should include tracked_ips")
|
||||||
|
}
|
||||||
|
if _, ok := stats["limits"]; !ok {
|
||||||
|
t.Error("Stats should include limits")
|
||||||
|
}
|
||||||
|
|
||||||
|
limits := stats["limits"].(map[string]interface{})
|
||||||
|
if _, ok := limits["REGISTER"]; !ok {
|
||||||
|
t.Error("Limits should include REGISTER config")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGlobalRateLimiter(t *testing.T) {
|
||||||
|
logger := zap.NewNop()
|
||||||
|
|
||||||
|
// Reset global rate limiter for clean test
|
||||||
|
rateLimiterMu.Lock()
|
||||||
|
globalRateLimiter = nil
|
||||||
|
rateLimiterMu.Unlock()
|
||||||
|
|
||||||
|
rl := GetRateLimiter(logger)
|
||||||
|
if rl == nil {
|
||||||
|
t.Fatal("GetRateLimiter should return non-nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should return same instance
|
||||||
|
rl2 := GetRateLimiter(logger)
|
||||||
|
if rl != rl2 {
|
||||||
|
t.Error("GetRateLimiter should return same global instance")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have default limits applied
|
||||||
|
registerLimit := rl.GetLimit(MethodREGISTER)
|
||||||
|
if registerLimit.MaxRequests != 10 {
|
||||||
|
t.Errorf("Global rate limiter should have default REGISTER limit, got %d", registerLimit.MaxRequests)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// SIP Parsing Function Tests
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
func TestExtractSIPMethod(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
data []byte
|
||||||
|
expected SIPMethod
|
||||||
|
}{
|
||||||
|
{"REGISTER", []byte("REGISTER sip:example.com SIP/2.0\r\n"), MethodREGISTER},
|
||||||
|
{"INVITE", []byte("INVITE sip:alice@example.com SIP/2.0\r\n"), MethodINVITE},
|
||||||
|
{"OPTIONS", []byte("OPTIONS sip:example.com SIP/2.0\r\n"), MethodOPTIONS},
|
||||||
|
{"ACK", []byte("ACK sip:alice@example.com SIP/2.0\r\n"), MethodACK},
|
||||||
|
{"BYE", []byte("BYE sip:alice@example.com SIP/2.0\r\n"), MethodBYE},
|
||||||
|
{"CANCEL", []byte("CANCEL sip:alice@example.com SIP/2.0\r\n"), MethodCANCEL},
|
||||||
|
{"INFO", []byte("INFO sip:alice@example.com SIP/2.0\r\n"), MethodINFO},
|
||||||
|
{"NOTIFY", []byte("NOTIFY sip:alice@example.com SIP/2.0\r\n"), MethodNOTIFY},
|
||||||
|
{"SUBSCRIBE", []byte("SUBSCRIBE sip:alice@example.com SIP/2.0\r\n"), MethodSUBSCRIBE},
|
||||||
|
{"MESSAGE", []byte("MESSAGE sip:alice@example.com SIP/2.0\r\n"), MethodMESSAGE},
|
||||||
|
{"UPDATE", []byte("UPDATE sip:alice@example.com SIP/2.0\r\n"), MethodUPDATE},
|
||||||
|
{"PRACK", []byte("PRACK sip:alice@example.com SIP/2.0\r\n"), MethodPRACK},
|
||||||
|
{"REFER", []byte("REFER sip:alice@example.com SIP/2.0\r\n"), MethodREFER},
|
||||||
|
{"PUBLISH", []byte("PUBLISH sip:alice@example.com SIP/2.0\r\n"), MethodPUBLISH},
|
||||||
|
{"Response (no method)", []byte("SIP/2.0 200 OK\r\n"), ""},
|
||||||
|
{"Non-SIP", []byte("GET / HTTP/1.1\r\n"), ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
method := ExtractSIPMethod(tt.data)
|
||||||
|
if method != tt.expected {
|
||||||
|
t.Errorf("ExtractSIPMethod: got %q, want %q", method, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: TestExtractTargetExtension is already defined in enumeration_test.go
|
||||||
|
// These additional tests cover edge cases for rate limiter integration
|
||||||
|
|
||||||
|
func TestExtractTargetExtensionEdgeCases(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
data []byte
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
// Legitimate extension extractions (additional cases)
|
||||||
|
{
|
||||||
|
name: "Simple 4-digit extension",
|
||||||
|
data: []byte("REGISTER sip:1001@example.com SIP/2.0\r\n"),
|
||||||
|
expected: "1001",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "3-digit extension",
|
||||||
|
data: []byte("INVITE sip:200@pbx.local SIP/2.0\r\n"),
|
||||||
|
expected: "200",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Alphanumeric short name",
|
||||||
|
data: []byte("INVITE sip:john@example.com SIP/2.0\r\n"),
|
||||||
|
expected: "john",
|
||||||
|
},
|
||||||
|
// Should NOT extract these as extensions (domain-like or too long)
|
||||||
|
{
|
||||||
|
name: "Full domain should not be extracted",
|
||||||
|
data: []byte("REGISTER sip:example.com SIP/2.0\r\n"),
|
||||||
|
expected: "", // Contains a dot, filtered out
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Too long identifier",
|
||||||
|
data: []byte("INVITE sip:verylongusername@example.com SIP/2.0\r\n"),
|
||||||
|
expected: "", // >10 chars
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ext := ExtractTargetExtension(tt.data)
|
||||||
|
if ext != tt.expected {
|
||||||
|
t.Errorf("ExtractTargetExtension: got %q, want %q", ext, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseUserAgent(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
data []byte
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Yealink phone",
|
||||||
|
data: []byte("REGISTER sip:example.com SIP/2.0\r\n" +
|
||||||
|
"User-Agent: Yealink SIP-T46S 66.86.0.15\r\n"),
|
||||||
|
expected: "Yealink SIP-T46S 66.86.0.15",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Linphone",
|
||||||
|
data: []byte("INVITE sip:alice@example.com SIP/2.0\r\n" +
|
||||||
|
"User-Agent: Linphone/4.5.0\r\n"),
|
||||||
|
expected: "Linphone/4.5.0",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No User-Agent",
|
||||||
|
data: []byte("OPTIONS sip:example.com SIP/2.0\r\n" +
|
||||||
|
"Via: SIP/2.0/UDP 192.168.1.1\r\n"),
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ua := ParseUserAgent(tt.data)
|
||||||
|
if ua != tt.expected {
|
||||||
|
t.Errorf("ParseUserAgent: got %q, want %q", ua, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseFromHeader(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
data []byte
|
||||||
|
expectedUser string
|
||||||
|
expectedDomain string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Standard From header",
|
||||||
|
data: []byte("INVITE sip:bob@example.com SIP/2.0\r\n" +
|
||||||
|
"From: \"Alice\" <sip:alice@sip.example.com>;tag=1234\r\n"),
|
||||||
|
expectedUser: "alice",
|
||||||
|
expectedDomain: "sip.example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "From header without display name",
|
||||||
|
data: []byte("REGISTER sip:example.com SIP/2.0\r\n" +
|
||||||
|
"From: <sip:1001@pbx.local>;tag=abc\r\n"),
|
||||||
|
expectedUser: "1001",
|
||||||
|
expectedDomain: "pbx.local",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No From header",
|
||||||
|
data: []byte("OPTIONS sip:example.com SIP/2.0\r\n" +
|
||||||
|
"To: <sip:bob@example.com>\r\n"),
|
||||||
|
expectedUser: "",
|
||||||
|
expectedDomain: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
user, domain := ParseFromHeader(tt.data)
|
||||||
|
if user != tt.expectedUser {
|
||||||
|
t.Errorf("ParseFromHeader user: got %q, want %q", user, tt.expectedUser)
|
||||||
|
}
|
||||||
|
if domain != tt.expectedDomain {
|
||||||
|
t.Errorf("ParseFromHeader domain: got %q, want %q", domain, tt.expectedDomain)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToHeader(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
data []byte
|
||||||
|
expectedUser string
|
||||||
|
expectedDomain string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Standard To header",
|
||||||
|
data: []byte("INVITE sip:bob@example.com SIP/2.0\r\n" +
|
||||||
|
"To: \"Bob\" <sip:bob@sip.example.com>\r\n"),
|
||||||
|
expectedUser: "bob",
|
||||||
|
expectedDomain: "sip.example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "To header without display name",
|
||||||
|
data: []byte("REGISTER sip:example.com SIP/2.0\r\n" +
|
||||||
|
"To: <sip:5002@pbx.local>\r\n"),
|
||||||
|
expectedUser: "5002",
|
||||||
|
expectedDomain: "pbx.local",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No To header",
|
||||||
|
data: []byte("OPTIONS sip:example.com SIP/2.0\r\n" +
|
||||||
|
"From: <sip:alice@example.com>\r\n"),
|
||||||
|
expectedUser: "",
|
||||||
|
expectedDomain: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
user, domain := ParseToHeader(tt.data)
|
||||||
|
if user != tt.expectedUser {
|
||||||
|
t.Errorf("ParseToHeader user: got %q, want %q", user, tt.expectedUser)
|
||||||
|
}
|
||||||
|
if domain != tt.expectedDomain {
|
||||||
|
t.Errorf("ParseToHeader domain: got %q, want %q", domain, tt.expectedDomain)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseCallID(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
data []byte
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Standard Call-ID header",
|
||||||
|
data: []byte("INVITE sip:bob@example.com SIP/2.0\r\n" +
|
||||||
|
"Call-ID: 12345-67890@192.168.1.100\r\n"),
|
||||||
|
expected: "12345-67890@192.168.1.100",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Short form Call-ID (i:)",
|
||||||
|
data: []byte("REGISTER sip:example.com SIP/2.0\r\n" +
|
||||||
|
"i: compact-callid@host\r\n"),
|
||||||
|
expected: "compact-callid@host",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No Call-ID",
|
||||||
|
data: []byte("OPTIONS sip:example.com SIP/2.0\r\n" +
|
||||||
|
"From: <sip:alice@example.com>\r\n"),
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
callID := ParseCallID(tt.data)
|
||||||
|
if callID != tt.expected {
|
||||||
|
t.Errorf("ParseCallID: got %q, want %q", callID, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// =============================================================================
|
||||||
|
// Real-world Legitimate Traffic Scenarios
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
func TestLegitimatePhoneRegistration(t *testing.T) {
|
||||||
|
// Simulate a typical phone registration pattern:
|
||||||
|
// 1. Initial REGISTER (usually 401/407 challenge)
|
||||||
|
// 2. Second REGISTER with auth
|
||||||
|
// 3. Keep-alive OPTIONS
|
||||||
|
// 4. Re-REGISTER before expiry
|
||||||
|
|
||||||
|
logger := zap.NewNop()
|
||||||
|
|
||||||
|
// Reset global rate limiter for clean test
|
||||||
|
rateLimiterMu.Lock()
|
||||||
|
globalRateLimiter = nil
|
||||||
|
rateLimiterMu.Unlock()
|
||||||
|
|
||||||
|
rl := GetRateLimiter(logger)
|
||||||
|
ip := "192.168.1.200"
|
||||||
|
|
||||||
|
// Simulate pattern over time
|
||||||
|
allowedCount := 0
|
||||||
|
|
||||||
|
// Initial REGISTER
|
||||||
|
if allowed, _ := rl.Allow(ip, MethodREGISTER); allowed {
|
||||||
|
allowedCount++
|
||||||
|
}
|
||||||
|
// Auth REGISTER
|
||||||
|
if allowed, _ := rl.Allow(ip, MethodREGISTER); allowed {
|
||||||
|
allowedCount++
|
||||||
|
}
|
||||||
|
// Keep-alive OPTIONS (phones do this frequently)
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
if allowed, _ := rl.Allow(ip, MethodOPTIONS); allowed {
|
||||||
|
allowedCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Another REGISTER
|
||||||
|
if allowed, _ := rl.Allow(ip, MethodREGISTER); allowed {
|
||||||
|
allowedCount++
|
||||||
|
}
|
||||||
|
|
||||||
|
// At minimum, the phone should complete registration (2 REGISTER + OPTIONS)
|
||||||
|
if allowedCount < 5 {
|
||||||
|
t.Errorf("Legitimate phone registration pattern blocked too early, only %d requests allowed", allowedCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLegitimateCallFlow(t *testing.T) {
|
||||||
|
// Simulate a typical call flow:
|
||||||
|
// INVITE -> 100 Trying -> 180 Ringing -> 200 OK -> ACK -> (call) -> BYE -> 200 OK
|
||||||
|
|
||||||
|
logger := zap.NewNop()
|
||||||
|
|
||||||
|
// Reset global rate limiter
|
||||||
|
rateLimiterMu.Lock()
|
||||||
|
globalRateLimiter = nil
|
||||||
|
rateLimiterMu.Unlock()
|
||||||
|
|
||||||
|
rl := GetRateLimiter(logger)
|
||||||
|
ip := "192.168.1.201"
|
||||||
|
|
||||||
|
// INVITE
|
||||||
|
allowed, reason := rl.Allow(ip, MethodINVITE)
|
||||||
|
if !allowed {
|
||||||
|
t.Errorf("INVITE should be allowed for call: %s", reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ACK (not rate limited by default)
|
||||||
|
allowed, reason = rl.Allow(ip, MethodACK)
|
||||||
|
if !allowed {
|
||||||
|
t.Errorf("ACK should be allowed for call: %s", reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BYE
|
||||||
|
allowed, reason = rl.Allow(ip, MethodBYE)
|
||||||
|
if !allowed {
|
||||||
|
t.Errorf("BYE should be allowed for call: %s", reason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLegitimateSubscriptionFlow(t *testing.T) {
|
||||||
|
// Simulate presence/BLF subscription
|
||||||
|
logger := zap.NewNop()
|
||||||
|
|
||||||
|
// Reset global rate limiter
|
||||||
|
rateLimiterMu.Lock()
|
||||||
|
globalRateLimiter = nil
|
||||||
|
rateLimiterMu.Unlock()
|
||||||
|
|
||||||
|
rl := GetRateLimiter(logger)
|
||||||
|
ip := "192.168.1.202"
|
||||||
|
|
||||||
|
// SUBSCRIBE for presence (phones do multiple for BLF)
|
||||||
|
allowedCount := 0
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
if allowed, _ := rl.Allow(ip, MethodSUBSCRIBE); allowed {
|
||||||
|
allowedCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default allows 5 burst + some refill
|
||||||
|
if allowedCount < 5 {
|
||||||
|
t.Errorf("Legitimate SUBSCRIBE pattern blocked too early, only %d allowed", allowedCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLegitimateMessaging(t *testing.T) {
|
||||||
|
// Simulate SMS/messaging traffic
|
||||||
|
logger := zap.NewNop()
|
||||||
|
|
||||||
|
// Reset global rate limiter
|
||||||
|
rateLimiterMu.Lock()
|
||||||
|
globalRateLimiter = nil
|
||||||
|
rateLimiterMu.Unlock()
|
||||||
|
|
||||||
|
rl := GetRateLimiter(logger)
|
||||||
|
ip := "192.168.1.203"
|
||||||
|
|
||||||
|
// MESSAGE (higher limit - 100/min default)
|
||||||
|
allowedCount := 0
|
||||||
|
for i := 0; i < 25; i++ {
|
||||||
|
if allowed, _ := rl.Allow(ip, MethodMESSAGE); allowed {
|
||||||
|
allowedCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should allow most messages (20 burst default)
|
||||||
|
if allowedCount < 20 {
|
||||||
|
t.Errorf("Legitimate MESSAGE pattern blocked too early, only %d allowed", allowedCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
116
sipguardian.go
116
sipguardian.go
@ -41,6 +41,11 @@ type SIPGuardian struct {
|
|||||||
BanTime caddy.Duration `json:"ban_time,omitempty"`
|
BanTime caddy.Duration `json:"ban_time,omitempty"`
|
||||||
WhitelistCIDR []string `json:"whitelist_cidr,omitempty"`
|
WhitelistCIDR []string `json:"whitelist_cidr,omitempty"`
|
||||||
|
|
||||||
|
// DNS-aware whitelist configuration
|
||||||
|
WhitelistHosts []string `json:"whitelist_hosts,omitempty"` // Hostnames to resolve (A/AAAA)
|
||||||
|
WhitelistSRV []string `json:"whitelist_srv,omitempty"` // SRV records to resolve
|
||||||
|
DNSRefresh caddy.Duration `json:"dns_refresh,omitempty"` // DNS refresh interval (default: 5m)
|
||||||
|
|
||||||
// Webhook configuration
|
// Webhook configuration
|
||||||
Webhooks []WebhookConfig `json:"webhooks,omitempty"`
|
Webhooks []WebhookConfig `json:"webhooks,omitempty"`
|
||||||
|
|
||||||
@ -48,7 +53,7 @@ type SIPGuardian struct {
|
|||||||
StoragePath string `json:"storage_path,omitempty"`
|
StoragePath string `json:"storage_path,omitempty"`
|
||||||
|
|
||||||
// GeoIP configuration
|
// GeoIP configuration
|
||||||
GeoIPPath string `json:"geoip_path,omitempty"`
|
GeoIPPath string `json:"geoip_path,omitempty"`
|
||||||
BlockedCountries []string `json:"blocked_countries,omitempty"`
|
BlockedCountries []string `json:"blocked_countries,omitempty"`
|
||||||
AllowedCountries []string `json:"allowed_countries,omitempty"`
|
AllowedCountries []string `json:"allowed_countries,omitempty"`
|
||||||
|
|
||||||
@ -63,6 +68,7 @@ type SIPGuardian struct {
|
|||||||
bannedIPs map[string]*BanEntry
|
bannedIPs map[string]*BanEntry
|
||||||
failureCounts map[string]*failureTracker
|
failureCounts map[string]*failureTracker
|
||||||
whitelistNets []*net.IPNet
|
whitelistNets []*net.IPNet
|
||||||
|
dnsWhitelist *DNSWhitelist
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
storage *Storage
|
storage *Storage
|
||||||
geoIP *GeoIPLookup
|
geoIP *GeoIPLookup
|
||||||
@ -155,6 +161,34 @@ func (g *SIPGuardian) Provision(ctx caddy.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize DNS whitelist if configured
|
||||||
|
if len(g.WhitelistHosts) > 0 || len(g.WhitelistSRV) > 0 {
|
||||||
|
refreshInterval := 5 * time.Minute
|
||||||
|
if g.DNSRefresh > 0 {
|
||||||
|
refreshInterval = time.Duration(g.DNSRefresh)
|
||||||
|
}
|
||||||
|
|
||||||
|
g.dnsWhitelist = NewDNSWhitelist(DNSWhitelistConfig{
|
||||||
|
Hostnames: g.WhitelistHosts,
|
||||||
|
SRVRecords: g.WhitelistSRV,
|
||||||
|
RefreshInterval: refreshInterval,
|
||||||
|
AllowStale: true,
|
||||||
|
ResolveTimeout: 10 * time.Second,
|
||||||
|
}, g.logger)
|
||||||
|
|
||||||
|
if err := g.dnsWhitelist.Start(); err != nil {
|
||||||
|
g.logger.Warn("Failed to initialize DNS whitelist",
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
g.logger.Info("DNS whitelist initialized",
|
||||||
|
zap.Int("hostnames", len(g.WhitelistHosts)),
|
||||||
|
zap.Int("srv_records", len(g.WhitelistSRV)),
|
||||||
|
zap.Duration("refresh_interval", refreshInterval),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Initialize enumeration detection with config if specified
|
// Initialize enumeration detection with config if specified
|
||||||
if g.Enumeration != nil {
|
if g.Enumeration != nil {
|
||||||
SetEnumerationConfig(*g.Enumeration)
|
SetEnumerationConfig(*g.Enumeration)
|
||||||
@ -216,12 +250,14 @@ func (g *SIPGuardian) loadBansFromStorage() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsWhitelisted checks if an IP is in the whitelist
|
// IsWhitelisted checks if an IP is in the whitelist (CIDR or DNS-based)
|
||||||
func (g *SIPGuardian) IsWhitelisted(ip string) bool {
|
func (g *SIPGuardian) IsWhitelisted(ip string) bool {
|
||||||
parsedIP := net.ParseIP(ip)
|
parsedIP := net.ParseIP(ip)
|
||||||
if parsedIP == nil {
|
if parsedIP == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check CIDR-based whitelist
|
||||||
for _, network := range g.whitelistNets {
|
for _, network := range g.whitelistNets {
|
||||||
if network.Contains(parsedIP) {
|
if network.Contains(parsedIP) {
|
||||||
if enableMetrics {
|
if enableMetrics {
|
||||||
@ -230,6 +266,19 @@ func (g *SIPGuardian) IsWhitelisted(ip string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check DNS-based whitelist
|
||||||
|
if g.dnsWhitelist != nil && g.dnsWhitelist.Contains(ip) {
|
||||||
|
if enableMetrics {
|
||||||
|
RecordWhitelistedConnection()
|
||||||
|
}
|
||||||
|
g.logger.Debug("IP whitelisted via DNS",
|
||||||
|
zap.String("ip", ip),
|
||||||
|
zap.String("source", g.dnsWhitelist.GetSource(ip).Source),
|
||||||
|
)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -448,10 +497,34 @@ func (g *SIPGuardian) GetStats() map[string]interface{} {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return map[string]interface{}{
|
stats := map[string]interface{}{
|
||||||
"active_bans": activeBans,
|
"active_bans": activeBans,
|
||||||
"tracked_failures": len(g.failureCounts),
|
"tracked_failures": len(g.failureCounts),
|
||||||
"whitelist_count": len(g.whitelistNets),
|
"whitelist_cidr": len(g.whitelistNets),
|
||||||
|
"whitelist_hosts": len(g.WhitelistHosts),
|
||||||
|
"whitelist_srv": len(g.WhitelistSRV),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add DNS whitelist stats if available
|
||||||
|
if g.dnsWhitelist != nil {
|
||||||
|
stats["dns_whitelist"] = g.dnsWhitelist.Stats()
|
||||||
|
}
|
||||||
|
|
||||||
|
return stats
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDNSWhitelistEntries returns all resolved DNS whitelist entries
|
||||||
|
func (g *SIPGuardian) GetDNSWhitelistEntries() []ResolvedEntry {
|
||||||
|
if g.dnsWhitelist == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return g.dnsWhitelist.GetResolvedIPs()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshDNSWhitelist forces an immediate refresh of DNS whitelist entries
|
||||||
|
func (g *SIPGuardian) RefreshDNSWhitelist() {
|
||||||
|
if g.dnsWhitelist != nil {
|
||||||
|
g.dnsWhitelist.ForceRefresh()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -500,8 +573,15 @@ func (g *SIPGuardian) cleanup() {
|
|||||||
// max_failures 5
|
// max_failures 5
|
||||||
// find_time 10m
|
// find_time 10m
|
||||||
// ban_time 1h
|
// ban_time 1h
|
||||||
|
//
|
||||||
|
// # IP/CIDR whitelist (static)
|
||||||
// whitelist 10.0.0.0/8 192.168.0.0/16
|
// whitelist 10.0.0.0/8 192.168.0.0/16
|
||||||
//
|
//
|
||||||
|
// # DNS-aware whitelist (dynamic, auto-refreshed)
|
||||||
|
// whitelist_hosts pbx.example.com trunk.provider.com
|
||||||
|
// whitelist_srv _sip._udp.provider.com _sip._tcp.carrier.net
|
||||||
|
// dns_refresh 5m # How often to refresh DNS (default: 5m)
|
||||||
|
//
|
||||||
// # Persistent storage
|
// # Persistent storage
|
||||||
// storage /var/lib/sip-guardian/guardian.db
|
// storage /var/lib/sip-guardian/guardian.db
|
||||||
//
|
//
|
||||||
@ -552,10 +632,34 @@ func (g *SIPGuardian) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
|||||||
g.BanTime = caddy.Duration(dur)
|
g.BanTime = caddy.Duration(dur)
|
||||||
|
|
||||||
case "whitelist":
|
case "whitelist":
|
||||||
|
// Legacy: CIDR-only whitelist
|
||||||
for d.NextArg() {
|
for d.NextArg() {
|
||||||
g.WhitelistCIDR = append(g.WhitelistCIDR, d.Val())
|
g.WhitelistCIDR = append(g.WhitelistCIDR, d.Val())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case "whitelist_hosts":
|
||||||
|
// DNS A/AAAA record whitelist (hostnames resolved to IPs)
|
||||||
|
for d.NextArg() {
|
||||||
|
g.WhitelistHosts = append(g.WhitelistHosts, d.Val())
|
||||||
|
}
|
||||||
|
|
||||||
|
case "whitelist_srv":
|
||||||
|
// DNS SRV record whitelist (e.g., _sip._udp.provider.com)
|
||||||
|
for d.NextArg() {
|
||||||
|
g.WhitelistSRV = append(g.WhitelistSRV, d.Val())
|
||||||
|
}
|
||||||
|
|
||||||
|
case "dns_refresh":
|
||||||
|
// Interval for refreshing DNS-based whitelist entries
|
||||||
|
if !d.NextArg() {
|
||||||
|
return d.ArgErr()
|
||||||
|
}
|
||||||
|
dur, err := caddy.ParseDuration(d.Val())
|
||||||
|
if err != nil {
|
||||||
|
return d.Errf("invalid dns_refresh: %v", err)
|
||||||
|
}
|
||||||
|
g.DNSRefresh = caddy.Duration(dur)
|
||||||
|
|
||||||
case "storage":
|
case "storage":
|
||||||
if !d.NextArg() {
|
if !d.NextArg() {
|
||||||
return d.ArgErr()
|
return d.ArgErr()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user