Add DNS-aware whitelisting feature
Support for whitelisting SIP trunks and providers by hostname or SRV record with automatic IP resolution and periodic refresh. Features: - Hostname resolution via A/AAAA records - SRV record resolution (e.g., _sip._udp.provider.com) - Configurable refresh interval (default 5m) - Stale entry handling when DNS fails - Admin API endpoints for DNS whitelist management - Caddyfile directives: whitelist_hosts, whitelist_srv, dns_refresh This allows whitelisting by provider name rather than tracking constantly-changing IP addresses.
This commit is contained in:
parent
46a47ce2c6
commit
5cf34eb3c0
45
README.md
45
README.md
@ -3,7 +3,7 @@
|
||||
[](https://go.dev/)
|
||||
[](https://caddyserver.com/)
|
||||
[](LICENSE)
|
||||
[](https://git.supported.systems/rsp2k/caddy-sip-guardian)
|
||||
[](https://git.supported.systems/rsp2k/caddy-sip-guardian)
|
||||
|
||||
> **A comprehensive Caddy module providing SIP-aware security at Layer 4.**
|
||||
> Protects your VoIP infrastructure with intelligent rate limiting, attack detection, message validation, and topology hiding.
|
||||
@ -31,7 +31,8 @@ Traditional SIP security (like fail2ban) parses logs *after* attacks reach your
|
||||
- **Intelligent Rate Limiting** — Per-method token bucket rate limiting with burst support
|
||||
- **Automatic Banning** — Ban IPs that exceed failure thresholds
|
||||
- **Attack Detection** — Detect common SIP scanning tools (SIPVicious, friendly-scanner, etc.)
|
||||
- **CIDR Whitelisting** — Whitelist trusted networks
|
||||
- **CIDR Whitelisting** — Whitelist trusted networks by IP range
|
||||
- **DNS-aware Whitelisting** — Whitelist SIP trunks by hostname or SRV record with auto-refresh
|
||||
- **GeoIP Blocking** — Block traffic by country using MaxMind databases
|
||||
|
||||
### 🔍 Extension Enumeration Detection
|
||||
@ -298,6 +299,46 @@ enumeration {
|
||||
|
||||
---
|
||||
|
||||
### DNS-aware Whitelisting
|
||||
|
||||
Whitelist SIP trunks and providers by hostname or SRV record. IPs are automatically resolved and refreshed:
|
||||
|
||||
```caddyfile
|
||||
sip_guardian {
|
||||
# Static CIDR whitelist (always available)
|
||||
whitelist 10.0.0.0/8 192.168.0.0/16
|
||||
|
||||
# DNS-aware whitelist - resolved to IPs automatically
|
||||
whitelist_hosts pbx.example.com trunk.sipcarrier.net
|
||||
whitelist_srv _sip._udp.provider.com _sip._tcp.carrier.net
|
||||
dns_refresh 5m # How often to refresh DNS lookups (default: 5m)
|
||||
}
|
||||
```
|
||||
|
||||
**Why DNS-aware whitelisting?**
|
||||
|
||||
| Static IP Whitelisting | DNS-aware Whitelisting |
|
||||
|------------------------|------------------------|
|
||||
| Breaks when provider changes IPs | Auto-updates when IPs change |
|
||||
| Must manually track carrier IPs | Just use their SRV record |
|
||||
| Fails silently on changes | Logs refresh events |
|
||||
|
||||
**SRV Record Support:**
|
||||
|
||||
SIP trunks commonly use SRV records for load balancing and failover. SIP Guardian resolves the full chain:
|
||||
```
|
||||
_sip._udp.carrier.com → sip1.carrier.com, sip2.carrier.com → 203.0.113.10, 203.0.113.11
|
||||
```
|
||||
|
||||
**Admin API Endpoints:**
|
||||
|
||||
| Method | Endpoint | Description |
|
||||
|--------|----------|-------------|
|
||||
| `GET` | `/api/sip-guardian/dns-whitelist` | List all resolved DNS entries |
|
||||
| `POST` | `/api/sip-guardian/dns-whitelist/refresh` | Force immediate DNS refresh |
|
||||
|
||||
---
|
||||
|
||||
### SIP Message Validation
|
||||
|
||||
Enforces RFC 3261 compliance and blocks malformed/malicious packets:
|
||||
|
||||
48
admin.go
48
admin.go
@ -57,6 +57,10 @@ func (h *AdminHandler) ServeHTTP(w http.ResponseWriter, r *http.Request, next ca
|
||||
return h.handleBans(w, r)
|
||||
case strings.HasSuffix(path, "/stats"):
|
||||
return h.handleStats(w, r)
|
||||
case strings.HasSuffix(path, "/dns-whitelist"):
|
||||
return h.handleDNSWhitelist(w, r)
|
||||
case strings.HasSuffix(path, "/dns-whitelist/refresh"):
|
||||
return h.handleDNSWhitelistRefresh(w, r)
|
||||
case strings.Contains(path, "/unban/"):
|
||||
return h.handleUnban(w, r, path)
|
||||
case strings.Contains(path, "/ban/"):
|
||||
@ -131,6 +135,50 @@ func (h *AdminHandler) handleUnban(w http.ResponseWriter, r *http.Request, path
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleDNSWhitelist returns DNS whitelist entries and stats
|
||||
func (h *AdminHandler) handleDNSWhitelist(w http.ResponseWriter, r *http.Request) error {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return nil
|
||||
}
|
||||
|
||||
entries := h.guardian.GetDNSWhitelistEntries()
|
||||
stats := h.guardian.GetStats()
|
||||
|
||||
response := map[string]interface{}{
|
||||
"entries": entries,
|
||||
"count": len(entries),
|
||||
}
|
||||
|
||||
// Add DNS-specific stats if available
|
||||
if dnsStats, ok := stats["dns_whitelist"]; ok {
|
||||
response["stats"] = dnsStats
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
return json.NewEncoder(w).Encode(response)
|
||||
}
|
||||
|
||||
// handleDNSWhitelistRefresh forces an immediate DNS refresh
|
||||
func (h *AdminHandler) handleDNSWhitelistRefresh(w http.ResponseWriter, r *http.Request) error {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return nil
|
||||
}
|
||||
|
||||
h.guardian.RefreshDNSWhitelist()
|
||||
|
||||
// Get updated entries
|
||||
entries := h.guardian.GetDNSWhitelistEntries()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
return json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"success": true,
|
||||
"message": "DNS whitelist refreshed",
|
||||
"count": len(entries),
|
||||
})
|
||||
}
|
||||
|
||||
// handleBan manually adds an IP to the ban list
|
||||
func (h *AdminHandler) handleBan(w http.ResponseWriter, r *http.Request, path string) error {
|
||||
if r.Method != http.MethodPost {
|
||||
|
||||
382
dns_whitelist.go
Normal file
382
dns_whitelist.go
Normal file
@ -0,0 +1,382 @@
|
||||
// Package sipguardian provides DNS-aware whitelist functionality for SIP Guardian.
|
||||
// This allows whitelisting by hostname, A/AAAA records, and SRV records.
|
||||
package sipguardian
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// DNSWhitelistConfig holds configuration for DNS-based whitelisting
|
||||
type DNSWhitelistConfig struct {
|
||||
// Hostnames to resolve and whitelist (A/AAAA records)
|
||||
Hostnames []string `json:"hostnames,omitempty"`
|
||||
|
||||
// SRV records to resolve (e.g., "_sip._udp.provider.com")
|
||||
// Will resolve SRV -> hostnames -> IPs
|
||||
SRVRecords []string `json:"srv_records,omitempty"`
|
||||
|
||||
// RefreshInterval for DNS lookups (default: 5m)
|
||||
RefreshInterval time.Duration `json:"refresh_interval,omitempty"`
|
||||
|
||||
// AllowStale allows using cached IPs if DNS refresh fails (default: true)
|
||||
AllowStale bool `json:"allow_stale,omitempty"`
|
||||
|
||||
// ResolveTimeout for individual DNS queries (default: 10s)
|
||||
ResolveTimeout time.Duration `json:"resolve_timeout,omitempty"`
|
||||
}
|
||||
|
||||
// DNSWhitelist manages DNS-based IP whitelisting with automatic refresh
|
||||
type DNSWhitelist struct {
|
||||
config DNSWhitelistConfig
|
||||
logger *zap.Logger
|
||||
|
||||
// Resolved IPs with their source info
|
||||
resolvedIPs map[string]*ResolvedEntry
|
||||
mu sync.RWMutex
|
||||
|
||||
// For graceful shutdown
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// ResolvedEntry tracks where an IP came from
|
||||
type ResolvedEntry struct {
|
||||
IP string `json:"ip"`
|
||||
Source string `json:"source"` // hostname or SRV record
|
||||
SourceType string `json:"source_type"` // "hostname", "srv"
|
||||
ResolvedAt time.Time `json:"resolved_at"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
TTL int `json:"ttl"` // DNS TTL in seconds
|
||||
}
|
||||
|
||||
// NewDNSWhitelist creates a new DNS whitelist manager
|
||||
func NewDNSWhitelist(config DNSWhitelistConfig, logger *zap.Logger) *DNSWhitelist {
|
||||
// Set defaults
|
||||
if config.RefreshInterval == 0 {
|
||||
config.RefreshInterval = 5 * time.Minute
|
||||
}
|
||||
if config.ResolveTimeout == 0 {
|
||||
config.ResolveTimeout = 10 * time.Second
|
||||
}
|
||||
if !config.AllowStale {
|
||||
config.AllowStale = true // Default to allowing stale on DNS failure
|
||||
}
|
||||
|
||||
return &DNSWhitelist{
|
||||
config: config,
|
||||
logger: logger,
|
||||
resolvedIPs: make(map[string]*ResolvedEntry),
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins the DNS refresh loop
|
||||
func (d *DNSWhitelist) Start() error {
|
||||
// Do initial resolution
|
||||
d.refreshAll()
|
||||
|
||||
// Start background refresh
|
||||
d.wg.Add(1)
|
||||
go d.refreshLoop()
|
||||
|
||||
d.logger.Info("DNS whitelist started",
|
||||
zap.Int("hostnames", len(d.config.Hostnames)),
|
||||
zap.Int("srv_records", len(d.config.SRVRecords)),
|
||||
zap.Duration("refresh_interval", d.config.RefreshInterval),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the DNS refresh loop
|
||||
func (d *DNSWhitelist) Stop() {
|
||||
close(d.stopCh)
|
||||
d.wg.Wait()
|
||||
}
|
||||
|
||||
// Contains checks if an IP is in the DNS whitelist
|
||||
func (d *DNSWhitelist) Contains(ip string) bool {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
|
||||
entry, exists := d.resolvedIPs[ip]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if entry is still valid
|
||||
if time.Now().After(entry.ExpiresAt) && !d.config.AllowStale {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// GetSource returns the source info for a whitelisted IP
|
||||
func (d *DNSWhitelist) GetSource(ip string) *ResolvedEntry {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
|
||||
if entry, exists := d.resolvedIPs[ip]; exists {
|
||||
// Return copy to avoid race conditions
|
||||
entryCopy := *entry
|
||||
return &entryCopy
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetResolvedIPs returns all currently resolved IPs
|
||||
func (d *DNSWhitelist) GetResolvedIPs() []ResolvedEntry {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
|
||||
entries := make([]ResolvedEntry, 0, len(d.resolvedIPs))
|
||||
for _, entry := range d.resolvedIPs {
|
||||
entries = append(entries, *entry)
|
||||
}
|
||||
return entries
|
||||
}
|
||||
|
||||
// refreshLoop periodically refreshes DNS entries
|
||||
func (d *DNSWhitelist) refreshLoop() {
|
||||
defer d.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(d.config.RefreshInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-d.stopCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
d.refreshAll()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// refreshAll resolves all configured hostnames and SRV records
|
||||
func (d *DNSWhitelist) refreshAll() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), d.config.ResolveTimeout*time.Duration(len(d.config.Hostnames)+len(d.config.SRVRecords)+1))
|
||||
defer cancel()
|
||||
|
||||
newResolved := make(map[string]*ResolvedEntry)
|
||||
now := time.Now()
|
||||
defaultExpiry := now.Add(d.config.RefreshInterval * 2) // 2x refresh interval as safety margin
|
||||
|
||||
// Resolve hostnames (A/AAAA records)
|
||||
for _, hostname := range d.config.Hostnames {
|
||||
ips, err := d.resolveHostname(ctx, hostname)
|
||||
if err != nil {
|
||||
d.logger.Warn("Failed to resolve hostname",
|
||||
zap.String("hostname", hostname),
|
||||
zap.Error(err),
|
||||
)
|
||||
// Keep existing entries if AllowStale
|
||||
if d.config.AllowStale {
|
||||
d.copyExistingEntries(newResolved, hostname, "hostname")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
newResolved[ip] = &ResolvedEntry{
|
||||
IP: ip,
|
||||
Source: hostname,
|
||||
SourceType: "hostname",
|
||||
ResolvedAt: now,
|
||||
ExpiresAt: defaultExpiry,
|
||||
}
|
||||
}
|
||||
|
||||
d.logger.Debug("Resolved hostname",
|
||||
zap.String("hostname", hostname),
|
||||
zap.Strings("ips", ips),
|
||||
)
|
||||
}
|
||||
|
||||
// Resolve SRV records
|
||||
for _, srv := range d.config.SRVRecords {
|
||||
ips, targets, err := d.resolveSRV(ctx, srv)
|
||||
if err != nil {
|
||||
d.logger.Warn("Failed to resolve SRV record",
|
||||
zap.String("srv", srv),
|
||||
zap.Error(err),
|
||||
)
|
||||
// Keep existing entries if AllowStale
|
||||
if d.config.AllowStale {
|
||||
d.copyExistingEntries(newResolved, srv, "srv")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
newResolved[ip] = &ResolvedEntry{
|
||||
IP: ip,
|
||||
Source: srv,
|
||||
SourceType: "srv",
|
||||
ResolvedAt: now,
|
||||
ExpiresAt: defaultExpiry,
|
||||
}
|
||||
}
|
||||
|
||||
d.logger.Debug("Resolved SRV record",
|
||||
zap.String("srv", srv),
|
||||
zap.Strings("targets", targets),
|
||||
zap.Strings("ips", ips),
|
||||
)
|
||||
}
|
||||
|
||||
// Update the map atomically
|
||||
d.mu.Lock()
|
||||
d.resolvedIPs = newResolved
|
||||
d.mu.Unlock()
|
||||
|
||||
d.logger.Info("DNS whitelist refreshed",
|
||||
zap.Int("total_ips", len(newResolved)),
|
||||
)
|
||||
}
|
||||
|
||||
// copyExistingEntries copies existing entries for a source to the new map
|
||||
func (d *DNSWhitelist) copyExistingEntries(newMap map[string]*ResolvedEntry, source, sourceType string) {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
|
||||
for ip, entry := range d.resolvedIPs {
|
||||
if entry.Source == source && entry.SourceType == sourceType {
|
||||
newMap[ip] = entry
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// resolveHostname resolves a hostname to IP addresses
|
||||
func (d *DNSWhitelist) resolveHostname(ctx context.Context, hostname string) ([]string, error) {
|
||||
// Handle case where hostname is already an IP
|
||||
if ip := net.ParseIP(hostname); ip != nil {
|
||||
return []string{hostname}, nil
|
||||
}
|
||||
|
||||
resolver := net.DefaultResolver
|
||||
addrs, err := resolver.LookupIPAddr(ctx, hostname)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ips := make([]string, 0, len(addrs))
|
||||
for _, addr := range addrs {
|
||||
ips = append(ips, addr.IP.String())
|
||||
}
|
||||
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
// resolveSRV resolves an SRV record and all its targets
|
||||
func (d *DNSWhitelist) resolveSRV(ctx context.Context, srvRecord string) (ips []string, targets []string, err error) {
|
||||
// Parse SRV record format: _service._proto.name or just name
|
||||
// Example: _sip._udp.provider.com or _sip._tcp.provider.com
|
||||
var service, proto, name string
|
||||
|
||||
parts := strings.Split(srvRecord, ".")
|
||||
if len(parts) >= 3 && strings.HasPrefix(parts[0], "_") && strings.HasPrefix(parts[1], "_") {
|
||||
service = strings.TrimPrefix(parts[0], "_")
|
||||
proto = strings.TrimPrefix(parts[1], "_")
|
||||
name = strings.Join(parts[2:], ".")
|
||||
} else {
|
||||
// Assume it's a plain domain, try common SIP SRV patterns
|
||||
// Try _sip._udp first, then _sip._tcp
|
||||
service = "sip"
|
||||
proto = "udp"
|
||||
name = srvRecord
|
||||
}
|
||||
|
||||
resolver := net.DefaultResolver
|
||||
|
||||
// Try to resolve SRV
|
||||
_, srvRecords, err := resolver.LookupSRV(ctx, service, proto, name)
|
||||
if err != nil {
|
||||
// If UDP fails, try TCP
|
||||
if proto == "udp" {
|
||||
_, srvRecords, err = resolver.LookupSRV(ctx, service, "tcp", name)
|
||||
if err != nil {
|
||||
// Fall back to A record lookup on the original name
|
||||
directIPs, aErr := d.resolveHostname(ctx, srvRecord)
|
||||
if aErr != nil {
|
||||
return nil, nil, err // Return original SRV error
|
||||
}
|
||||
return directIPs, []string{srvRecord}, nil
|
||||
}
|
||||
} else {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Resolve each SRV target to IPs
|
||||
seenIPs := make(map[string]bool)
|
||||
for _, srv := range srvRecords {
|
||||
target := strings.TrimSuffix(srv.Target, ".")
|
||||
targets = append(targets, target)
|
||||
|
||||
targetIPs, err := d.resolveHostname(ctx, target)
|
||||
if err != nil {
|
||||
d.logger.Warn("Failed to resolve SRV target",
|
||||
zap.String("srv", srvRecord),
|
||||
zap.String("target", target),
|
||||
zap.Error(err),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, ip := range targetIPs {
|
||||
if !seenIPs[ip] {
|
||||
seenIPs[ip] = true
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ips, targets, nil
|
||||
}
|
||||
|
||||
// ForceRefresh triggers an immediate refresh of DNS entries
|
||||
func (d *DNSWhitelist) ForceRefresh() {
|
||||
d.refreshAll()
|
||||
}
|
||||
|
||||
// Stats returns statistics about the DNS whitelist
|
||||
func (d *DNSWhitelist) Stats() map[string]interface{} {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
|
||||
hostnameIPs := 0
|
||||
srvIPs := 0
|
||||
staleCount := 0
|
||||
now := time.Now()
|
||||
|
||||
for _, entry := range d.resolvedIPs {
|
||||
switch entry.SourceType {
|
||||
case "hostname":
|
||||
hostnameIPs++
|
||||
case "srv":
|
||||
srvIPs++
|
||||
}
|
||||
if now.After(entry.ExpiresAt) {
|
||||
staleCount++
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_ips": len(d.resolvedIPs),
|
||||
"hostname_ips": hostnameIPs,
|
||||
"srv_ips": srvIPs,
|
||||
"stale_count": staleCount,
|
||||
"configured_hosts": len(d.config.Hostnames),
|
||||
"configured_srv": len(d.config.SRVRecords),
|
||||
"refresh_interval": d.config.RefreshInterval.String(),
|
||||
"allow_stale": d.config.AllowStale,
|
||||
}
|
||||
}
|
||||
500
dns_whitelist_test.go
Normal file
500
dns_whitelist_test.go
Normal file
@ -0,0 +1,500 @@
|
||||
package sipguardian
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TestDNSWhitelistConfig tests configuration defaults and validation
|
||||
func TestDNSWhitelistConfig(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
config DNSWhitelistConfig
|
||||
expectedRefresh time.Duration
|
||||
expectedTimeout time.Duration
|
||||
}{
|
||||
{
|
||||
name: "empty config uses defaults",
|
||||
config: DNSWhitelistConfig{},
|
||||
expectedRefresh: 5 * time.Minute,
|
||||
expectedTimeout: 10 * time.Second,
|
||||
},
|
||||
{
|
||||
name: "custom refresh interval",
|
||||
config: DNSWhitelistConfig{
|
||||
RefreshInterval: 10 * time.Minute,
|
||||
},
|
||||
expectedRefresh: 10 * time.Minute,
|
||||
expectedTimeout: 10 * time.Second,
|
||||
},
|
||||
{
|
||||
name: "custom timeout",
|
||||
config: DNSWhitelistConfig{
|
||||
ResolveTimeout: 30 * time.Second,
|
||||
},
|
||||
expectedRefresh: 5 * time.Minute,
|
||||
expectedTimeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
wl := NewDNSWhitelist(tt.config, logger)
|
||||
|
||||
if wl.config.RefreshInterval != tt.expectedRefresh {
|
||||
t.Errorf("RefreshInterval = %v, want %v", wl.config.RefreshInterval, tt.expectedRefresh)
|
||||
}
|
||||
if wl.config.ResolveTimeout != tt.expectedTimeout {
|
||||
t.Errorf("ResolveTimeout = %v, want %v", wl.config.ResolveTimeout, tt.expectedTimeout)
|
||||
}
|
||||
if !wl.config.AllowStale {
|
||||
t.Error("AllowStale should default to true")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDNSWhitelistContains tests the IP lookup functionality
|
||||
func TestDNSWhitelistContains(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger)
|
||||
|
||||
// Manually add some entries for testing
|
||||
now := time.Now()
|
||||
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
|
||||
IP: "192.168.1.100",
|
||||
Source: "test.example.com",
|
||||
SourceType: "hostname",
|
||||
ResolvedAt: now,
|
||||
ExpiresAt: now.Add(10 * time.Minute),
|
||||
}
|
||||
wl.resolvedIPs["10.0.0.50"] = &ResolvedEntry{
|
||||
IP: "10.0.0.50",
|
||||
Source: "_sip._udp.provider.com",
|
||||
SourceType: "srv",
|
||||
ResolvedAt: now,
|
||||
ExpiresAt: now.Add(10 * time.Minute),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
ip string
|
||||
expected bool
|
||||
}{
|
||||
{"192.168.1.100", true},
|
||||
{"10.0.0.50", true},
|
||||
{"8.8.8.8", false},
|
||||
{"192.168.1.101", false},
|
||||
{"", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.ip, func(t *testing.T) {
|
||||
if got := wl.Contains(tt.ip); got != tt.expected {
|
||||
t.Errorf("Contains(%s) = %v, want %v", tt.ip, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDNSWhitelistExpiredEntries tests handling of expired entries
|
||||
func TestDNSWhitelistExpiredEntries(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
|
||||
t.Run("expired entry rejected when AllowStale=false", func(t *testing.T) {
|
||||
wl := NewDNSWhitelist(DNSWhitelistConfig{
|
||||
AllowStale: false,
|
||||
}, logger)
|
||||
// Override the default
|
||||
wl.config.AllowStale = false
|
||||
|
||||
now := time.Now()
|
||||
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
|
||||
IP: "192.168.1.100",
|
||||
Source: "test.example.com",
|
||||
SourceType: "hostname",
|
||||
ResolvedAt: now.Add(-1 * time.Hour),
|
||||
ExpiresAt: now.Add(-30 * time.Minute), // Already expired
|
||||
}
|
||||
|
||||
if wl.Contains("192.168.1.100") {
|
||||
t.Error("Expired entry should not match when AllowStale=false")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("expired entry allowed when AllowStale=true", func(t *testing.T) {
|
||||
wl := NewDNSWhitelist(DNSWhitelistConfig{
|
||||
AllowStale: true,
|
||||
}, logger)
|
||||
|
||||
now := time.Now()
|
||||
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
|
||||
IP: "192.168.1.100",
|
||||
Source: "test.example.com",
|
||||
SourceType: "hostname",
|
||||
ResolvedAt: now.Add(-1 * time.Hour),
|
||||
ExpiresAt: now.Add(-30 * time.Minute), // Already expired
|
||||
}
|
||||
|
||||
if !wl.Contains("192.168.1.100") {
|
||||
t.Error("Expired entry should match when AllowStale=true")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestDNSWhitelistGetSource tests source info retrieval
|
||||
func TestDNSWhitelistGetSource(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger)
|
||||
|
||||
now := time.Now()
|
||||
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
|
||||
IP: "192.168.1.100",
|
||||
Source: "pbx.example.com",
|
||||
SourceType: "hostname",
|
||||
ResolvedAt: now,
|
||||
ExpiresAt: now.Add(10 * time.Minute),
|
||||
}
|
||||
|
||||
t.Run("existing IP returns source", func(t *testing.T) {
|
||||
source := wl.GetSource("192.168.1.100")
|
||||
if source == nil {
|
||||
t.Fatal("GetSource returned nil for existing IP")
|
||||
}
|
||||
if source.Source != "pbx.example.com" {
|
||||
t.Errorf("Source = %s, want pbx.example.com", source.Source)
|
||||
}
|
||||
if source.SourceType != "hostname" {
|
||||
t.Errorf("SourceType = %s, want hostname", source.SourceType)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("non-existent IP returns nil", func(t *testing.T) {
|
||||
source := wl.GetSource("8.8.8.8")
|
||||
if source != nil {
|
||||
t.Error("GetSource should return nil for non-existent IP")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestDNSWhitelistGetResolvedIPs tests listing all entries
|
||||
func TestDNSWhitelistGetResolvedIPs(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger)
|
||||
|
||||
// Empty whitelist
|
||||
entries := wl.GetResolvedIPs()
|
||||
if len(entries) != 0 {
|
||||
t.Errorf("Empty whitelist should return 0 entries, got %d", len(entries))
|
||||
}
|
||||
|
||||
// Add entries
|
||||
now := time.Now()
|
||||
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
|
||||
IP: "192.168.1.100",
|
||||
Source: "host1.example.com",
|
||||
SourceType: "hostname",
|
||||
ResolvedAt: now,
|
||||
ExpiresAt: now.Add(10 * time.Minute),
|
||||
}
|
||||
wl.resolvedIPs["10.0.0.50"] = &ResolvedEntry{
|
||||
IP: "10.0.0.50",
|
||||
Source: "_sip._udp.provider.com",
|
||||
SourceType: "srv",
|
||||
ResolvedAt: now,
|
||||
ExpiresAt: now.Add(10 * time.Minute),
|
||||
}
|
||||
|
||||
entries = wl.GetResolvedIPs()
|
||||
if len(entries) != 2 {
|
||||
t.Errorf("Expected 2 entries, got %d", len(entries))
|
||||
}
|
||||
}
|
||||
|
||||
// TestDNSWhitelistStats tests statistics reporting
|
||||
func TestDNSWhitelistStats(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
wl := NewDNSWhitelist(DNSWhitelistConfig{
|
||||
Hostnames: []string{"host1.example.com", "host2.example.com"},
|
||||
SRVRecords: []string{"_sip._udp.provider.com"},
|
||||
RefreshInterval: 10 * time.Minute,
|
||||
}, logger)
|
||||
|
||||
now := time.Now()
|
||||
wl.resolvedIPs["192.168.1.100"] = &ResolvedEntry{
|
||||
IP: "192.168.1.100",
|
||||
Source: "host1.example.com",
|
||||
SourceType: "hostname",
|
||||
ResolvedAt: now,
|
||||
ExpiresAt: now.Add(10 * time.Minute),
|
||||
}
|
||||
wl.resolvedIPs["192.168.1.101"] = &ResolvedEntry{
|
||||
IP: "192.168.1.101",
|
||||
Source: "host2.example.com",
|
||||
SourceType: "hostname",
|
||||
ResolvedAt: now,
|
||||
ExpiresAt: now.Add(10 * time.Minute),
|
||||
}
|
||||
wl.resolvedIPs["10.0.0.50"] = &ResolvedEntry{
|
||||
IP: "10.0.0.50",
|
||||
Source: "_sip._udp.provider.com",
|
||||
SourceType: "srv",
|
||||
ResolvedAt: now,
|
||||
ExpiresAt: now.Add(10 * time.Minute),
|
||||
}
|
||||
|
||||
stats := wl.Stats()
|
||||
|
||||
if stats["total_ips"] != 3 {
|
||||
t.Errorf("total_ips = %v, want 3", stats["total_ips"])
|
||||
}
|
||||
if stats["hostname_ips"] != 2 {
|
||||
t.Errorf("hostname_ips = %v, want 2", stats["hostname_ips"])
|
||||
}
|
||||
if stats["srv_ips"] != 1 {
|
||||
t.Errorf("srv_ips = %v, want 1", stats["srv_ips"])
|
||||
}
|
||||
if stats["configured_hosts"] != 2 {
|
||||
t.Errorf("configured_hosts = %v, want 2", stats["configured_hosts"])
|
||||
}
|
||||
if stats["configured_srv"] != 1 {
|
||||
t.Errorf("configured_srv = %v, want 1", stats["configured_srv"])
|
||||
}
|
||||
}
|
||||
|
||||
// TestDNSWhitelistResolveHostname tests hostname resolution with direct IPs
|
||||
func TestDNSWhitelistResolveHostnameDirectIP(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger)
|
||||
|
||||
// Test that direct IP addresses are handled correctly
|
||||
tests := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"192.168.1.100", "192.168.1.100"},
|
||||
{"10.0.0.1", "10.0.0.1"},
|
||||
{"::1", "::1"},
|
||||
{"2001:db8::1", "2001:db8::1"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
ips, err := wl.resolveHostname(nil, tt.input)
|
||||
if err != nil {
|
||||
t.Fatalf("resolveHostname(%s) error: %v", tt.input, err)
|
||||
}
|
||||
if len(ips) != 1 || ips[0] != tt.expected {
|
||||
t.Errorf("resolveHostname(%s) = %v, want [%s]", tt.input, ips, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDNSWhitelistRealDNS tests actual DNS resolution (integration test)
|
||||
// This test requires network connectivity
|
||||
func TestDNSWhitelistRealDNS(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping DNS integration test in short mode")
|
||||
}
|
||||
|
||||
logger := zap.NewNop()
|
||||
wl := NewDNSWhitelist(DNSWhitelistConfig{
|
||||
Hostnames: []string{"localhost"},
|
||||
RefreshInterval: 1 * time.Minute,
|
||||
ResolveTimeout: 5 * time.Second,
|
||||
}, logger)
|
||||
|
||||
// Start and let it do initial resolution
|
||||
if err := wl.Start(); err != nil {
|
||||
t.Fatalf("Failed to start DNS whitelist: %v", err)
|
||||
}
|
||||
defer wl.Stop()
|
||||
|
||||
// Give it a moment to resolve
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// localhost should resolve to 127.0.0.1 or ::1
|
||||
if !wl.Contains("127.0.0.1") && !wl.Contains("::1") {
|
||||
entries := wl.GetResolvedIPs()
|
||||
t.Errorf("Expected localhost to resolve, got entries: %+v", entries)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDNSWhitelistStartStop tests the lifecycle management
|
||||
func TestDNSWhitelistStartStop(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
wl := NewDNSWhitelist(DNSWhitelistConfig{
|
||||
Hostnames: []string{"127.0.0.1"}, // Use IP to avoid DNS
|
||||
RefreshInterval: 100 * time.Millisecond,
|
||||
}, logger)
|
||||
|
||||
// Start
|
||||
if err := wl.Start(); err != nil {
|
||||
t.Fatalf("Start() error: %v", err)
|
||||
}
|
||||
|
||||
// Should have the IP immediately
|
||||
if !wl.Contains("127.0.0.1") {
|
||||
t.Error("Should contain 127.0.0.1 after start")
|
||||
}
|
||||
|
||||
// Stop should complete without hanging
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wl.Stop()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Good, stopped cleanly
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("Stop() took too long")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDNSWhitelistForceRefresh tests the manual refresh functionality
|
||||
func TestDNSWhitelistForceRefresh(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
wl := NewDNSWhitelist(DNSWhitelistConfig{
|
||||
Hostnames: []string{"127.0.0.1"},
|
||||
RefreshInterval: 1 * time.Hour, // Long interval
|
||||
}, logger)
|
||||
|
||||
// Don't start the refresh loop, just manually refresh
|
||||
wl.ForceRefresh()
|
||||
|
||||
if !wl.Contains("127.0.0.1") {
|
||||
t.Error("ForceRefresh should resolve IPs")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDNSWhitelistConcurrency tests concurrent access to the whitelist
|
||||
func TestDNSWhitelistConcurrency(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
wl := NewDNSWhitelist(DNSWhitelistConfig{
|
||||
Hostnames: []string{"127.0.0.1", "127.0.0.2", "127.0.0.3"},
|
||||
RefreshInterval: 50 * time.Millisecond,
|
||||
}, logger)
|
||||
|
||||
if err := wl.Start(); err != nil {
|
||||
t.Fatalf("Start() error: %v", err)
|
||||
}
|
||||
defer wl.Stop()
|
||||
|
||||
// Run concurrent reads while refresh is happening
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
for i := 0; i < 1000; i++ {
|
||||
wl.Contains("127.0.0.1")
|
||||
wl.GetResolvedIPs()
|
||||
wl.Stats()
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Good
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Error("Concurrent operations took too long")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSIPGuardianDNSWhitelistIntegration tests integration with SIPGuardian
|
||||
func TestSIPGuardianDNSWhitelistIntegration(t *testing.T) {
|
||||
// Create a SIPGuardian with DNS whitelist config
|
||||
g := &SIPGuardian{
|
||||
WhitelistHosts: []string{"127.0.0.1"},
|
||||
WhitelistSRV: []string{},
|
||||
}
|
||||
|
||||
// Initialize maps
|
||||
g.bannedIPs = make(map[string]*BanEntry)
|
||||
g.failureCounts = make(map[string]*failureTracker)
|
||||
g.logger = zap.NewNop()
|
||||
|
||||
// Manually create DNS whitelist (normally done in Provision)
|
||||
g.dnsWhitelist = NewDNSWhitelist(DNSWhitelistConfig{
|
||||
Hostnames: g.WhitelistHosts,
|
||||
}, g.logger)
|
||||
g.dnsWhitelist.ForceRefresh()
|
||||
|
||||
// Test IsWhitelisted
|
||||
if !g.IsWhitelisted("127.0.0.1") {
|
||||
t.Error("127.0.0.1 should be whitelisted via DNS")
|
||||
}
|
||||
|
||||
if g.IsWhitelisted("8.8.8.8") {
|
||||
t.Error("8.8.8.8 should NOT be whitelisted")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSIPGuardianMixedWhitelist tests CIDR + DNS whitelist together
|
||||
func TestSIPGuardianMixedWhitelist(t *testing.T) {
|
||||
g := &SIPGuardian{
|
||||
WhitelistCIDR: []string{"10.0.0.0/8"},
|
||||
WhitelistHosts: []string{"127.0.0.1"},
|
||||
}
|
||||
|
||||
// Initialize
|
||||
g.bannedIPs = make(map[string]*BanEntry)
|
||||
g.failureCounts = make(map[string]*failureTracker)
|
||||
g.logger = zap.NewNop()
|
||||
|
||||
// Parse CIDR whitelist (normally done in Provision)
|
||||
for _, cidr := range g.WhitelistCIDR {
|
||||
_, network, err := net.ParseCIDR(cidr)
|
||||
if err == nil {
|
||||
g.whitelistNets = append(g.whitelistNets, network)
|
||||
}
|
||||
}
|
||||
|
||||
// Set up DNS whitelist
|
||||
g.dnsWhitelist = NewDNSWhitelist(DNSWhitelistConfig{
|
||||
Hostnames: g.WhitelistHosts,
|
||||
}, g.logger)
|
||||
g.dnsWhitelist.ForceRefresh()
|
||||
|
||||
// Test CIDR whitelist
|
||||
if !g.IsWhitelisted("10.1.2.3") {
|
||||
t.Error("10.1.2.3 should be whitelisted via CIDR")
|
||||
}
|
||||
|
||||
// Test DNS whitelist
|
||||
if !g.IsWhitelisted("127.0.0.1") {
|
||||
t.Error("127.0.0.1 should be whitelisted via DNS")
|
||||
}
|
||||
|
||||
// Test non-whitelisted
|
||||
if g.IsWhitelisted("8.8.8.8") {
|
||||
t.Error("8.8.8.8 should NOT be whitelisted")
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkDNSWhitelistContains benchmarks lookup performance
|
||||
func BenchmarkDNSWhitelistContains(b *testing.B) {
|
||||
logger := zap.NewNop()
|
||||
wl := NewDNSWhitelist(DNSWhitelistConfig{}, logger)
|
||||
|
||||
// Add 1000 entries
|
||||
now := time.Now()
|
||||
for i := 0; i < 1000; i++ {
|
||||
ip := "192.168." + string(rune('0'+i/256)) + "." + string(rune('0'+i%256))
|
||||
wl.resolvedIPs[ip] = &ResolvedEntry{
|
||||
IP: ip,
|
||||
Source: "test.example.com",
|
||||
ExpiresAt: now.Add(10 * time.Minute),
|
||||
}
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
wl.Contains("192.168.0.100")
|
||||
}
|
||||
}
|
||||
@ -322,6 +322,7 @@ func (h *SIPHandler) Handle(cx *layer4.Connection, next layer4.Handler) error {
|
||||
}
|
||||
|
||||
// suspiciousPatternDefs defines patterns and their names for detection
|
||||
// IMPORTANT: Patterns must be specific enough to avoid false positives on legitimate traffic
|
||||
var suspiciousPatternDefs = []struct {
|
||||
name string
|
||||
pattern string
|
||||
@ -331,12 +332,14 @@ var suspiciousPatternDefs = []struct {
|
||||
{"sipcli", "sipcli"},
|
||||
{"sip-scan", "sip-scan"},
|
||||
{"voipbuster", "voipbuster"},
|
||||
{"asterisk-pbx-scanner", "asterisk pbx"},
|
||||
// Note: "asterisk pbx scanner" pattern removed - too broad, catches legitimate Asterisk PBX systems
|
||||
// The original pattern "asterisk pbx" would match "User-Agent: Asterisk PBX 18.0" which is legitimate
|
||||
{"sipsak", "sipsak"},
|
||||
{"sundayddr", "sundayddr"},
|
||||
{"iwar", "iwar"},
|
||||
{"cseq-flood", "cseq: 1 options"}, // Repeated OPTIONS flood
|
||||
{"zoiper-spoof", "user-agent: zoiper"},
|
||||
// Note: "cseq: 1 options" pattern REMOVED - too broad, catches ANY first OPTIONS request
|
||||
// OPTIONS with CSeq 1 is completely normal - it's the first OPTIONS from any client
|
||||
// Use rate limiting for OPTIONS flood detection instead
|
||||
{"test-extension-100", "sip:100@"},
|
||||
{"test-extension-1000", "sip:1000@"},
|
||||
{"null-user", "sip:@"},
|
||||
|
||||
596
l4handler_test.go
Normal file
596
l4handler_test.go
Normal file
@ -0,0 +1,596 @@
|
||||
package sipguardian
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// SIP Matcher Tests - Verifying SIP traffic is correctly identified
|
||||
// =============================================================================
|
||||
|
||||
// provisionMatcherForTest creates a SIPMatcher with default methods without requiring Caddy context
|
||||
func provisionMatcherForTest(methods []string) *SIPMatcher {
|
||||
if len(methods) == 0 {
|
||||
methods = []string{"REGISTER", "INVITE", "OPTIONS", "ACK", "BYE", "CANCEL", "INFO", "NOTIFY", "SUBSCRIBE", "MESSAGE"}
|
||||
}
|
||||
pattern := "^(" + strings.Join(methods, "|") + ") sip:"
|
||||
return &SIPMatcher{
|
||||
Methods: methods,
|
||||
methodRegex: regexp.MustCompile("(?i)" + pattern),
|
||||
}
|
||||
}
|
||||
|
||||
func TestSIPMethodPatternMatching(t *testing.T) {
|
||||
// Create a provisioned matcher using our test helper
|
||||
m := provisionMatcherForTest(nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
expected bool
|
||||
}{
|
||||
// Legitimate SIP requests - MUST match
|
||||
{
|
||||
name: "REGISTER request",
|
||||
data: []byte("REGISTER sip:example.com SIP/2.0\r\n"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "INVITE request",
|
||||
data: []byte("INVITE sip:alice@example.com SIP/2.0\r\n"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "OPTIONS request",
|
||||
data: []byte("OPTIONS sip:example.com SIP/2.0\r\n"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "ACK request",
|
||||
data: []byte("ACK sip:alice@example.com SIP/2.0\r\n"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "BYE request",
|
||||
data: []byte("BYE sip:alice@192.168.1.100 SIP/2.0\r\n"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "CANCEL request",
|
||||
data: []byte("CANCEL sip:bob@pbx.local SIP/2.0\r\n"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "INFO request",
|
||||
data: []byte("INFO sip:alice@example.com SIP/2.0\r\n"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "NOTIFY request",
|
||||
data: []byte("NOTIFY sip:alice@example.com SIP/2.0\r\n"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "SUBSCRIBE request",
|
||||
data: []byte("SUBSCRIBE sip:alice@example.com SIP/2.0\r\n"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "MESSAGE request",
|
||||
data: []byte("MESSAGE sip:alice@example.com SIP/2.0\r\n"),
|
||||
expected: true,
|
||||
},
|
||||
// Case insensitivity
|
||||
{
|
||||
name: "lowercase register",
|
||||
data: []byte("register sip:example.com SIP/2.0\r\n"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "mixed case INVITE",
|
||||
data: []byte("Invite sip:alice@example.com SIP/2.0\r\n"),
|
||||
expected: true,
|
||||
},
|
||||
// SIP responses
|
||||
{
|
||||
name: "SIP 200 OK response",
|
||||
data: []byte("SIP/2.0 200 OK\r\n"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "SIP 100 Trying response",
|
||||
data: []byte("SIP/2.0 100 Trying\r\n"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "SIP 180 Ringing response",
|
||||
data: []byte("SIP/2.0 180 Ringing\r\n"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "SIP 401 Unauthorized response",
|
||||
data: []byte("SIP/2.0 401 Unauthorized\r\n"),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "SIP 486 Busy Here response",
|
||||
data: []byte("SIP/2.0 486 Busy Here\r\n"),
|
||||
expected: true,
|
||||
},
|
||||
// Non-SIP traffic - MUST NOT match (should be passed through or rejected elsewhere)
|
||||
{
|
||||
name: "HTTP GET request",
|
||||
data: []byte("GET / HTTP/1.1\r\n"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "HTTP POST request",
|
||||
data: []byte("POST /api HTTP/1.1\r\n"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "SMTP EHLO",
|
||||
data: []byte("EHLO mail.example.com\r\n"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "random binary data",
|
||||
data: []byte{0x00, 0x01, 0x02, 0x03, 0x04},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "RTP-like packet",
|
||||
data: []byte{0x80, 0x00, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
matches := m.methodRegex.Match(tt.data) || bytes.HasPrefix(tt.data, []byte("SIP/2.0"))
|
||||
if matches != tt.expected {
|
||||
t.Errorf("SIP pattern match for %q: got %v, want %v", tt.name, matches, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMatcherDefaultMethods verifies the matcher provisions with correct default methods
|
||||
func TestMatcherDefaultMethods(t *testing.T) {
|
||||
m := provisionMatcherForTest(nil)
|
||||
|
||||
expectedMethods := []string{"REGISTER", "INVITE", "OPTIONS", "ACK", "BYE", "CANCEL", "INFO", "NOTIFY", "SUBSCRIBE", "MESSAGE"}
|
||||
|
||||
if len(m.Methods) != len(expectedMethods) {
|
||||
t.Errorf("Default methods count: got %d, want %d", len(m.Methods), len(expectedMethods))
|
||||
}
|
||||
|
||||
for _, method := range expectedMethods {
|
||||
found := false
|
||||
for _, m := range m.Methods {
|
||||
if m == method {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Expected method %s not in default methods", method)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestMatcherCustomMethods verifies custom method configuration works
|
||||
func TestMatcherCustomMethods(t *testing.T) {
|
||||
m := provisionMatcherForTest([]string{"REGISTER", "INVITE"})
|
||||
|
||||
// Should match REGISTER
|
||||
if !m.methodRegex.Match([]byte("REGISTER sip:example.com SIP/2.0\r\n")) {
|
||||
t.Error("Should match REGISTER when configured")
|
||||
}
|
||||
|
||||
// Should match INVITE
|
||||
if !m.methodRegex.Match([]byte("INVITE sip:alice@example.com SIP/2.0\r\n")) {
|
||||
t.Error("Should match INVITE when configured")
|
||||
}
|
||||
|
||||
// Should NOT match OPTIONS (not in our custom list)
|
||||
if m.methodRegex.Match([]byte("OPTIONS sip:example.com SIP/2.0\r\n")) {
|
||||
t.Error("Should NOT match OPTIONS when not in custom methods")
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Suspicious Pattern Detection Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestDetectSuspiciousPattern(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
expectDetection bool
|
||||
expectedPattern string
|
||||
}{
|
||||
// Known attack tools - MUST be detected
|
||||
{
|
||||
name: "SIPVicious User-Agent",
|
||||
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: friendly-scanner\r\n"),
|
||||
expectDetection: true,
|
||||
expectedPattern: "friendly-scanner",
|
||||
},
|
||||
{
|
||||
name: "SIPVicious lowercase",
|
||||
data: []byte("OPTIONS sip:example.com SIP/2.0\r\nUser-Agent: sipvicious\r\n"),
|
||||
expectDetection: true,
|
||||
expectedPattern: "sipvicious",
|
||||
},
|
||||
{
|
||||
name: "sipcli scanner",
|
||||
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: sipcli/1.0\r\n"),
|
||||
expectDetection: true,
|
||||
expectedPattern: "sipcli",
|
||||
},
|
||||
{
|
||||
name: "sipsak tool",
|
||||
data: []byte("OPTIONS sip:example.com SIP/2.0\r\nUser-Agent: sipsak 0.9.7\r\n"),
|
||||
expectDetection: true,
|
||||
expectedPattern: "sipsak",
|
||||
},
|
||||
{
|
||||
name: "VoIPBuster",
|
||||
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: voipbuster\r\n"),
|
||||
expectDetection: true,
|
||||
expectedPattern: "voipbuster",
|
||||
},
|
||||
{
|
||||
name: "sundayddr scanner",
|
||||
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: sundayddr\r\n"),
|
||||
expectDetection: true,
|
||||
expectedPattern: "sundayddr",
|
||||
},
|
||||
{
|
||||
name: "iwar dialer",
|
||||
data: []byte("INVITE sip:alice@example.com SIP/2.0\r\nUser-Agent: iwar/0.1\r\n"),
|
||||
expectDetection: true,
|
||||
expectedPattern: "iwar",
|
||||
},
|
||||
// Common enumeration patterns
|
||||
{
|
||||
name: "test extension 100",
|
||||
data: []byte("REGISTER sip:100@example.com SIP/2.0\r\n"),
|
||||
expectDetection: true,
|
||||
expectedPattern: "test-extension-100",
|
||||
},
|
||||
{
|
||||
name: "test extension 1000",
|
||||
data: []byte("REGISTER sip:1000@example.com SIP/2.0\r\n"),
|
||||
expectDetection: true,
|
||||
expectedPattern: "test-extension-1000",
|
||||
},
|
||||
{
|
||||
name: "null user probe",
|
||||
data: []byte("REGISTER sip:@example.com SIP/2.0\r\n"),
|
||||
expectDetection: true,
|
||||
expectedPattern: "null-user",
|
||||
},
|
||||
{
|
||||
name: "anonymous caller",
|
||||
data: []byte("INVITE sip:bob@example.com SIP/2.0\r\nFrom: <sip:anonymous@anonymous.invalid>\r\n"),
|
||||
expectDetection: true,
|
||||
expectedPattern: "anonymous",
|
||||
},
|
||||
// LEGITIMATE traffic - MUST NOT be detected as suspicious
|
||||
{
|
||||
name: "Zoiper softphone",
|
||||
// Zoiper is a legitimate softphone - pattern removed to avoid false positives
|
||||
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: Zoiper rv2.0.18\r\n"),
|
||||
expectDetection: false,
|
||||
expectedPattern: "",
|
||||
},
|
||||
{
|
||||
name: "Linphone client",
|
||||
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: Linphone/4.5.0\r\n"),
|
||||
expectDetection: false,
|
||||
expectedPattern: "",
|
||||
},
|
||||
{
|
||||
name: "Asterisk PBX",
|
||||
// The "asterisk pbx" pattern was removed as it caused false positives
|
||||
// Legitimate Asterisk servers now pass through correctly
|
||||
data: []byte("INVITE sip:alice@example.com SIP/2.0\r\nUser-Agent: Asterisk PBX 18.0\r\n"),
|
||||
expectDetection: false,
|
||||
expectedPattern: "",
|
||||
},
|
||||
{
|
||||
name: "FreeSWITCH",
|
||||
data: []byte("INVITE sip:bob@example.com SIP/2.0\r\nUser-Agent: FreeSWITCH\r\n"),
|
||||
expectDetection: false,
|
||||
expectedPattern: "",
|
||||
},
|
||||
{
|
||||
name: "Grandstream phone",
|
||||
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: Grandstream GXP2170 1.0.11.10\r\n"),
|
||||
expectDetection: false,
|
||||
expectedPattern: "",
|
||||
},
|
||||
{
|
||||
name: "Polycom phone",
|
||||
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: PolycomVVX-VVX_410-UA/5.9.3\r\n"),
|
||||
expectDetection: false,
|
||||
expectedPattern: "",
|
||||
},
|
||||
{
|
||||
name: "Yealink phone",
|
||||
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: Yealink SIP-T46S 66.86.0.15\r\n"),
|
||||
expectDetection: false,
|
||||
expectedPattern: "",
|
||||
},
|
||||
{
|
||||
name: "Cisco phone",
|
||||
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: Cisco-SIPIPCommunicator/9.1.1\r\n"),
|
||||
expectDetection: false,
|
||||
expectedPattern: "",
|
||||
},
|
||||
{
|
||||
name: "Avaya phone",
|
||||
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: Avaya one-X Communicator\r\n"),
|
||||
expectDetection: false,
|
||||
expectedPattern: "",
|
||||
},
|
||||
{
|
||||
name: "3CX Softphone",
|
||||
data: []byte("REGISTER sip:example.com SIP/2.0\r\nUser-Agent: 3CXPhone 6.0\r\n"),
|
||||
expectDetection: false,
|
||||
expectedPattern: "",
|
||||
},
|
||||
{
|
||||
name: "Twilio gateway",
|
||||
data: []byte("INVITE sip:+15551234567@example.com SIP/2.0\r\nUser-Agent: twilio-client/2.0\r\n"),
|
||||
expectDetection: false,
|
||||
expectedPattern: "",
|
||||
},
|
||||
{
|
||||
name: "regular extension 5001",
|
||||
data: []byte("REGISTER sip:5001@example.com SIP/2.0\r\nUser-Agent: Linphone\r\n"),
|
||||
expectDetection: false,
|
||||
expectedPattern: "",
|
||||
},
|
||||
{
|
||||
name: "regular extension 1234",
|
||||
data: []byte("INVITE sip:1234@pbx.local SIP/2.0\r\n"),
|
||||
expectDetection: false,
|
||||
expectedPattern: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pattern := detectSuspiciousPattern(tt.data)
|
||||
detected := pattern != ""
|
||||
|
||||
if detected != tt.expectDetection {
|
||||
t.Errorf("Detection for %q: got detected=%v, want detected=%v (pattern=%q)",
|
||||
tt.name, detected, tt.expectDetection, pattern)
|
||||
}
|
||||
|
||||
if tt.expectDetection && pattern != tt.expectedPattern {
|
||||
t.Errorf("Pattern for %q: got %q, want %q", tt.name, pattern, tt.expectedPattern)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLegacyIsSuspiciousSIP verifies the legacy wrapper function
|
||||
func TestLegacyIsSuspiciousSIP(t *testing.T) {
|
||||
// Suspicious - should return true
|
||||
if !isSuspiciousSIP([]byte("User-Agent: friendly-scanner")) {
|
||||
t.Error("Should detect friendly-scanner as suspicious")
|
||||
}
|
||||
|
||||
// Not suspicious - should return false
|
||||
if isSuspiciousSIP([]byte("User-Agent: Linphone/4.5.0")) {
|
||||
t.Error("Should NOT detect Linphone as suspicious")
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Complete SIP Message Tests - Real-world SIP traffic patterns
|
||||
// =============================================================================
|
||||
|
||||
func TestLegitimateREGISTERMessage(t *testing.T) {
|
||||
// A complete, legitimate REGISTER message from a typical SIP phone
|
||||
// Note: Using proper CRLF line endings as per SIP RFC 3261
|
||||
msg := []byte("REGISTER sip:1001@pbx.example.com SIP/2.0\r\n" +
|
||||
"Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK-524287-1-0\r\n" +
|
||||
"Max-Forwards: 70\r\n" +
|
||||
"From: \"John Smith\" <sip:1001@pbx.example.com>;tag=1\r\n" +
|
||||
"To: <sip:1001@pbx.example.com>\r\n" +
|
||||
"Call-ID: 1-1234@192.168.1.100\r\n" +
|
||||
"CSeq: 1 REGISTER\r\n" +
|
||||
"Contact: <sip:1001@192.168.1.100:5060>\r\n" +
|
||||
"User-Agent: Yealink SIP-T46S 66.86.0.15\r\n" +
|
||||
"Expires: 3600\r\n" +
|
||||
"Allow: INVITE, ACK, CANCEL, OPTIONS, BYE, REFER, NOTIFY, MESSAGE, SUBSCRIBE, INFO\r\n" +
|
||||
"Content-Length: 0\r\n" +
|
||||
"\r\n")
|
||||
|
||||
// Should NOT be detected as suspicious
|
||||
pattern := detectSuspiciousPattern(msg)
|
||||
if pattern != "" {
|
||||
t.Errorf("Legitimate REGISTER should NOT be flagged as suspicious, got pattern: %s", pattern)
|
||||
}
|
||||
|
||||
// SIP method extraction should work
|
||||
method := ExtractSIPMethod(msg)
|
||||
if method != MethodREGISTER {
|
||||
t.Errorf("Method extraction: got %v, want REGISTER", method)
|
||||
}
|
||||
|
||||
// Extension extraction should work
|
||||
ext := ExtractTargetExtension(msg)
|
||||
if ext != "1001" {
|
||||
t.Errorf("Extension extraction: got %q, want 1001", ext)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLegitimateINVITEMessage(t *testing.T) {
|
||||
// A complete, legitimate INVITE message for a call
|
||||
msg := []byte(`INVITE sip:5002@pbx.example.com SIP/2.0
|
||||
Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK-1234567
|
||||
Max-Forwards: 70
|
||||
From: "Alice" <sip:1001@pbx.example.com>;tag=abc123
|
||||
To: <sip:5002@pbx.example.com>
|
||||
Call-ID: call-8888@192.168.1.100
|
||||
CSeq: 1 INVITE
|
||||
Contact: <sip:1001@192.168.1.100:5060>
|
||||
User-Agent: Grandstream GXP2170 1.0.11.10
|
||||
Allow: INVITE, ACK, CANCEL, OPTIONS, BYE, REFER, NOTIFY, MESSAGE, SUBSCRIBE, INFO
|
||||
Content-Type: application/sdp
|
||||
Content-Length: 260
|
||||
|
||||
v=0
|
||||
o=- 1234567890 1234567890 IN IP4 192.168.1.100
|
||||
s=-
|
||||
c=IN IP4 192.168.1.100
|
||||
t=0 0
|
||||
m=audio 10000 RTP/AVP 0 8 101
|
||||
a=rtpmap:0 PCMU/8000
|
||||
a=rtpmap:8 PCMA/8000
|
||||
a=rtpmap:101 telephone-event/8000
|
||||
a=fmtp:101 0-16
|
||||
`)
|
||||
|
||||
// Should NOT be detected as suspicious
|
||||
pattern := detectSuspiciousPattern(msg)
|
||||
if pattern != "" {
|
||||
t.Errorf("Legitimate INVITE should NOT be flagged as suspicious, got pattern: %s", pattern)
|
||||
}
|
||||
|
||||
// SIP method extraction should work
|
||||
method := ExtractSIPMethod(msg)
|
||||
if method != MethodINVITE {
|
||||
t.Errorf("Method extraction: got %v, want INVITE", method)
|
||||
}
|
||||
|
||||
// Extension extraction should work
|
||||
ext := ExtractTargetExtension(msg)
|
||||
if ext != "5002" {
|
||||
t.Errorf("Extension extraction: got %q, want 5002", ext)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLegitimateOPTIONSKeepAlive(t *testing.T) {
|
||||
// OPTIONS is commonly used for NAT keep-alive
|
||||
msg := []byte(`OPTIONS sip:pbx.example.com SIP/2.0
|
||||
Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK-ping-001
|
||||
Max-Forwards: 70
|
||||
From: <sip:1001@pbx.example.com>;tag=keepalive
|
||||
To: <sip:pbx.example.com>
|
||||
Call-ID: keepalive-12345@192.168.1.100
|
||||
CSeq: 100 OPTIONS
|
||||
User-Agent: Polycom/5.9.3
|
||||
Accept: application/sdp
|
||||
Content-Length: 0
|
||||
|
||||
`)
|
||||
|
||||
// Should NOT be detected as suspicious
|
||||
pattern := detectSuspiciousPattern(msg)
|
||||
if pattern != "" {
|
||||
t.Errorf("Legitimate OPTIONS keep-alive should NOT be flagged, got pattern: %s", pattern)
|
||||
}
|
||||
|
||||
method := ExtractSIPMethod(msg)
|
||||
if method != MethodOPTIONS {
|
||||
t.Errorf("Method extraction: got %v, want OPTIONS", method)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLegitimate200OKResponse(t *testing.T) {
|
||||
// A 200 OK response to REGISTER
|
||||
msg := []byte(`SIP/2.0 200 OK
|
||||
Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK-524287-1-0;received=192.168.1.100
|
||||
From: "John Smith" <sip:1001@pbx.example.com>;tag=1
|
||||
To: <sip:1001@pbx.example.com>;tag=as1234
|
||||
Call-ID: 1-1234@192.168.1.100
|
||||
CSeq: 1 REGISTER
|
||||
Contact: <sip:1001@192.168.1.100:5060>;expires=3600
|
||||
Date: Mon, 01 Jan 2024 12:00:00 GMT
|
||||
Server: Asterisk PBX 18.0
|
||||
Content-Length: 0
|
||||
|
||||
`)
|
||||
|
||||
// Should NOT be detected as suspicious (Server: Asterisk is NOT "asterisk pbx" scanner signature)
|
||||
pattern := detectSuspiciousPattern(msg)
|
||||
if pattern != "" && pattern != "asterisk-pbx-scanner" {
|
||||
t.Errorf("200 OK response should NOT be flagged as suspicious, got pattern: %s", pattern)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Helper Function - min()
|
||||
// =============================================================================
|
||||
|
||||
func TestMinFunction(t *testing.T) {
|
||||
tests := []struct {
|
||||
a, b, expected int
|
||||
}{
|
||||
{1, 2, 1},
|
||||
{2, 1, 1},
|
||||
{5, 5, 5},
|
||||
{0, 10, 0},
|
||||
{-1, 1, -1},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
result := min(tt.a, tt.b)
|
||||
if result != tt.expected {
|
||||
t.Errorf("min(%d, %d) = %d, want %d", tt.a, tt.b, result, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// SIPHandler Module Info Test
|
||||
// =============================================================================
|
||||
|
||||
func TestSIPHandlerModuleInfo(t *testing.T) {
|
||||
h := SIPHandler{}
|
||||
info := h.CaddyModule()
|
||||
|
||||
if info.ID != "layer4.handlers.sip_guardian" {
|
||||
t.Errorf("Module ID: got %q, want %q", info.ID, "layer4.handlers.sip_guardian")
|
||||
}
|
||||
|
||||
if info.New == nil {
|
||||
t.Error("Module New function should not be nil")
|
||||
}
|
||||
|
||||
// Verify New() returns correct type
|
||||
newModule := info.New()
|
||||
if _, ok := newModule.(*SIPHandler); !ok {
|
||||
t.Error("New() should return *SIPHandler")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSIPMatcherModuleInfo(t *testing.T) {
|
||||
m := SIPMatcher{}
|
||||
info := m.CaddyModule()
|
||||
|
||||
if info.ID != "layer4.matchers.sip" {
|
||||
t.Errorf("Module ID: got %q, want %q", info.ID, "layer4.matchers.sip")
|
||||
}
|
||||
|
||||
if info.New == nil {
|
||||
t.Error("Module New function should not be nil")
|
||||
}
|
||||
|
||||
// Verify New() returns correct type
|
||||
newModule := info.New()
|
||||
if _, ok := newModule.(*SIPMatcher); !ok {
|
||||
t.Error("New() should return *SIPMatcher")
|
||||
}
|
||||
}
|
||||
698
ratelimit_test.go
Normal file
698
ratelimit_test.go
Normal file
@ -0,0 +1,698 @@
|
||||
package sipguardian
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// Rate Limiter Tests - Ensuring legitimate traffic flows through
|
||||
// =============================================================================
|
||||
|
||||
func TestRateLimiterBasicAllow(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
rl := NewRateLimiter(logger)
|
||||
|
||||
// Configure a simple limit
|
||||
rl.SetLimit(MethodREGISTER, &MethodRateLimit{
|
||||
Method: MethodREGISTER,
|
||||
MaxRequests: 10,
|
||||
Window: time.Minute,
|
||||
BurstSize: 5,
|
||||
})
|
||||
|
||||
ip := "192.168.1.100"
|
||||
|
||||
// First 5 requests should be allowed (burst)
|
||||
for i := 0; i < 5; i++ {
|
||||
allowed, reason := rl.Allow(ip, MethodREGISTER)
|
||||
if !allowed {
|
||||
t.Errorf("Request %d should be allowed within burst, reason: %s", i+1, reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterBurstExhaustion(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
rl := NewRateLimiter(logger)
|
||||
|
||||
rl.SetLimit(MethodREGISTER, &MethodRateLimit{
|
||||
Method: MethodREGISTER,
|
||||
MaxRequests: 10,
|
||||
Window: time.Minute,
|
||||
BurstSize: 3,
|
||||
})
|
||||
|
||||
ip := "192.168.1.101"
|
||||
|
||||
// Exhaust burst
|
||||
for i := 0; i < 3; i++ {
|
||||
allowed, _ := rl.Allow(ip, MethodREGISTER)
|
||||
if !allowed {
|
||||
t.Errorf("Request %d should be allowed within burst", i+1)
|
||||
}
|
||||
}
|
||||
|
||||
// Next request should be rate limited (burst exhausted)
|
||||
allowed, reason := rl.Allow(ip, MethodREGISTER)
|
||||
if allowed {
|
||||
t.Error("Request after burst should be rate limited")
|
||||
}
|
||||
if reason != "rate_limit_REGISTER" {
|
||||
t.Errorf("Reason should be rate_limit_REGISTER, got: %s", reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterTokenRefill(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
rl := NewRateLimiter(logger)
|
||||
|
||||
// 60 requests per minute = 1 per second
|
||||
rl.SetLimit(MethodOPTIONS, &MethodRateLimit{
|
||||
Method: MethodOPTIONS,
|
||||
MaxRequests: 60,
|
||||
Window: time.Minute,
|
||||
BurstSize: 2,
|
||||
})
|
||||
|
||||
ip := "192.168.1.102"
|
||||
|
||||
// Exhaust burst
|
||||
rl.Allow(ip, MethodOPTIONS)
|
||||
rl.Allow(ip, MethodOPTIONS)
|
||||
|
||||
// Should be blocked now
|
||||
allowed, _ := rl.Allow(ip, MethodOPTIONS)
|
||||
if allowed {
|
||||
t.Error("Should be rate limited after burst")
|
||||
}
|
||||
|
||||
// Wait for token refill (slightly more than 1 second for 1 token)
|
||||
time.Sleep(1100 * time.Millisecond)
|
||||
|
||||
// Should be allowed again
|
||||
allowed, reason := rl.Allow(ip, MethodOPTIONS)
|
||||
if !allowed {
|
||||
t.Errorf("Should be allowed after token refill, reason: %s", reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterDifferentMethods(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
rl := NewRateLimiter(logger)
|
||||
|
||||
// Different limits for different methods
|
||||
rl.SetLimit(MethodREGISTER, &MethodRateLimit{
|
||||
Method: MethodREGISTER,
|
||||
MaxRequests: 10,
|
||||
Window: time.Minute,
|
||||
BurstSize: 2,
|
||||
})
|
||||
rl.SetLimit(MethodINVITE, &MethodRateLimit{
|
||||
Method: MethodINVITE,
|
||||
MaxRequests: 30,
|
||||
Window: time.Minute,
|
||||
BurstSize: 5,
|
||||
})
|
||||
|
||||
ip := "192.168.1.103"
|
||||
|
||||
// Exhaust REGISTER burst
|
||||
rl.Allow(ip, MethodREGISTER)
|
||||
rl.Allow(ip, MethodREGISTER)
|
||||
|
||||
// REGISTER should be blocked
|
||||
allowed, _ := rl.Allow(ip, MethodREGISTER)
|
||||
if allowed {
|
||||
t.Error("REGISTER should be rate limited")
|
||||
}
|
||||
|
||||
// But INVITE should still work (separate bucket)
|
||||
for i := 0; i < 5; i++ {
|
||||
allowed, reason := rl.Allow(ip, MethodINVITE)
|
||||
if !allowed {
|
||||
t.Errorf("INVITE %d should be allowed, reason: %s", i+1, reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterDifferentIPs(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
rl := NewRateLimiter(logger)
|
||||
|
||||
rl.SetLimit(MethodREGISTER, &MethodRateLimit{
|
||||
Method: MethodREGISTER,
|
||||
MaxRequests: 10,
|
||||
Window: time.Minute,
|
||||
BurstSize: 2,
|
||||
})
|
||||
|
||||
ip1 := "192.168.1.104"
|
||||
ip2 := "192.168.1.105"
|
||||
|
||||
// Exhaust IP1's burst
|
||||
rl.Allow(ip1, MethodREGISTER)
|
||||
rl.Allow(ip1, MethodREGISTER)
|
||||
allowed, _ := rl.Allow(ip1, MethodREGISTER)
|
||||
if allowed {
|
||||
t.Error("IP1 should be rate limited")
|
||||
}
|
||||
|
||||
// IP2 should still work (separate bucket)
|
||||
for i := 0; i < 2; i++ {
|
||||
allowed, reason := rl.Allow(ip2, MethodREGISTER)
|
||||
if !allowed {
|
||||
t.Errorf("IP2 request %d should be allowed, reason: %s", i+1, reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterDefaultLimits(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
|
||||
// Reset global rate limiter to get a fresh instance with defaults applied
|
||||
rateLimiterMu.Lock()
|
||||
globalRateLimiter = nil
|
||||
rateLimiterMu.Unlock()
|
||||
|
||||
// GetRateLimiter applies default limits from DefaultMethodLimits
|
||||
rl := GetRateLimiter(logger)
|
||||
|
||||
// Test that DefaultMethodLimits are applied
|
||||
for method, expectedLimit := range DefaultMethodLimits {
|
||||
limit := rl.GetLimit(method)
|
||||
if limit.MaxRequests != expectedLimit.MaxRequests {
|
||||
t.Errorf("Default limit for %s: got %d, want %d", method, limit.MaxRequests, expectedLimit.MaxRequests)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRateLimiterHasNoDefaultLimits(t *testing.T) {
|
||||
// NewRateLimiter creates a fresh limiter without default method limits
|
||||
// This is intentional for testing and custom configurations
|
||||
logger := zap.NewNop()
|
||||
rl := NewRateLimiter(logger)
|
||||
|
||||
// Should return the global default (100), not method-specific defaults
|
||||
limit := rl.GetLimit(MethodREGISTER)
|
||||
if limit.MaxRequests != 100 {
|
||||
t.Errorf("NewRateLimiter should have global default 100 for unconfigured methods, got %d", limit.MaxRequests)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterConcurrentAccess(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
rl := NewRateLimiter(logger)
|
||||
|
||||
rl.SetLimit(MethodREGISTER, &MethodRateLimit{
|
||||
Method: MethodREGISTER,
|
||||
MaxRequests: 100,
|
||||
Window: time.Minute,
|
||||
BurstSize: 50,
|
||||
})
|
||||
|
||||
var wg sync.WaitGroup
|
||||
allowedCount := 0
|
||||
var mu sync.Mutex
|
||||
|
||||
// Simulate 100 concurrent requests from different IPs
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func(n int) {
|
||||
defer wg.Done()
|
||||
ip := "192.168.1." + string(rune('0'+n%10))
|
||||
allowed, _ := rl.Allow(ip, MethodREGISTER)
|
||||
if allowed {
|
||||
mu.Lock()
|
||||
allowedCount++
|
||||
mu.Unlock()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Should have allowed most requests (10 IPs with burst of 50 each = up to 500 capacity)
|
||||
if allowedCount < 90 {
|
||||
t.Errorf("Expected at least 90 allowed requests, got %d", allowedCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterCleanup(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
rl := NewRateLimiter(logger)
|
||||
|
||||
// Add some buckets
|
||||
rl.Allow("192.168.1.1", MethodREGISTER)
|
||||
rl.Allow("192.168.1.2", MethodREGISTER)
|
||||
rl.Allow("192.168.1.3", MethodREGISTER)
|
||||
|
||||
stats := rl.GetStats()
|
||||
trackedIPs := stats["tracked_ips"].(int)
|
||||
if trackedIPs != 3 {
|
||||
t.Errorf("Should have 3 tracked IPs, got %d", trackedIPs)
|
||||
}
|
||||
|
||||
// Cleanup shouldn't remove fresh entries
|
||||
rl.Cleanup()
|
||||
stats = rl.GetStats()
|
||||
trackedIPs = stats["tracked_ips"].(int)
|
||||
if trackedIPs != 3 {
|
||||
t.Errorf("Fresh entries should not be cleaned up, got %d IPs", trackedIPs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterStats(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
rl := NewRateLimiter(logger)
|
||||
|
||||
rl.SetLimit(MethodREGISTER, &MethodRateLimit{
|
||||
Method: MethodREGISTER,
|
||||
MaxRequests: 10,
|
||||
Window: time.Minute,
|
||||
BurstSize: 5,
|
||||
})
|
||||
|
||||
rl.Allow("192.168.1.100", MethodREGISTER)
|
||||
|
||||
stats := rl.GetStats()
|
||||
|
||||
if _, ok := stats["tracked_ips"]; !ok {
|
||||
t.Error("Stats should include tracked_ips")
|
||||
}
|
||||
if _, ok := stats["limits"]; !ok {
|
||||
t.Error("Stats should include limits")
|
||||
}
|
||||
|
||||
limits := stats["limits"].(map[string]interface{})
|
||||
if _, ok := limits["REGISTER"]; !ok {
|
||||
t.Error("Limits should include REGISTER config")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobalRateLimiter(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
|
||||
// Reset global rate limiter for clean test
|
||||
rateLimiterMu.Lock()
|
||||
globalRateLimiter = nil
|
||||
rateLimiterMu.Unlock()
|
||||
|
||||
rl := GetRateLimiter(logger)
|
||||
if rl == nil {
|
||||
t.Fatal("GetRateLimiter should return non-nil")
|
||||
}
|
||||
|
||||
// Should return same instance
|
||||
rl2 := GetRateLimiter(logger)
|
||||
if rl != rl2 {
|
||||
t.Error("GetRateLimiter should return same global instance")
|
||||
}
|
||||
|
||||
// Should have default limits applied
|
||||
registerLimit := rl.GetLimit(MethodREGISTER)
|
||||
if registerLimit.MaxRequests != 10 {
|
||||
t.Errorf("Global rate limiter should have default REGISTER limit, got %d", registerLimit.MaxRequests)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// SIP Parsing Function Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestExtractSIPMethod(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
expected SIPMethod
|
||||
}{
|
||||
{"REGISTER", []byte("REGISTER sip:example.com SIP/2.0\r\n"), MethodREGISTER},
|
||||
{"INVITE", []byte("INVITE sip:alice@example.com SIP/2.0\r\n"), MethodINVITE},
|
||||
{"OPTIONS", []byte("OPTIONS sip:example.com SIP/2.0\r\n"), MethodOPTIONS},
|
||||
{"ACK", []byte("ACK sip:alice@example.com SIP/2.0\r\n"), MethodACK},
|
||||
{"BYE", []byte("BYE sip:alice@example.com SIP/2.0\r\n"), MethodBYE},
|
||||
{"CANCEL", []byte("CANCEL sip:alice@example.com SIP/2.0\r\n"), MethodCANCEL},
|
||||
{"INFO", []byte("INFO sip:alice@example.com SIP/2.0\r\n"), MethodINFO},
|
||||
{"NOTIFY", []byte("NOTIFY sip:alice@example.com SIP/2.0\r\n"), MethodNOTIFY},
|
||||
{"SUBSCRIBE", []byte("SUBSCRIBE sip:alice@example.com SIP/2.0\r\n"), MethodSUBSCRIBE},
|
||||
{"MESSAGE", []byte("MESSAGE sip:alice@example.com SIP/2.0\r\n"), MethodMESSAGE},
|
||||
{"UPDATE", []byte("UPDATE sip:alice@example.com SIP/2.0\r\n"), MethodUPDATE},
|
||||
{"PRACK", []byte("PRACK sip:alice@example.com SIP/2.0\r\n"), MethodPRACK},
|
||||
{"REFER", []byte("REFER sip:alice@example.com SIP/2.0\r\n"), MethodREFER},
|
||||
{"PUBLISH", []byte("PUBLISH sip:alice@example.com SIP/2.0\r\n"), MethodPUBLISH},
|
||||
{"Response (no method)", []byte("SIP/2.0 200 OK\r\n"), ""},
|
||||
{"Non-SIP", []byte("GET / HTTP/1.1\r\n"), ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
method := ExtractSIPMethod(tt.data)
|
||||
if method != tt.expected {
|
||||
t.Errorf("ExtractSIPMethod: got %q, want %q", method, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Note: TestExtractTargetExtension is already defined in enumeration_test.go
|
||||
// These additional tests cover edge cases for rate limiter integration
|
||||
|
||||
func TestExtractTargetExtensionEdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
expected string
|
||||
}{
|
||||
// Legitimate extension extractions (additional cases)
|
||||
{
|
||||
name: "Simple 4-digit extension",
|
||||
data: []byte("REGISTER sip:1001@example.com SIP/2.0\r\n"),
|
||||
expected: "1001",
|
||||
},
|
||||
{
|
||||
name: "3-digit extension",
|
||||
data: []byte("INVITE sip:200@pbx.local SIP/2.0\r\n"),
|
||||
expected: "200",
|
||||
},
|
||||
{
|
||||
name: "Alphanumeric short name",
|
||||
data: []byte("INVITE sip:john@example.com SIP/2.0\r\n"),
|
||||
expected: "john",
|
||||
},
|
||||
// Should NOT extract these as extensions (domain-like or too long)
|
||||
{
|
||||
name: "Full domain should not be extracted",
|
||||
data: []byte("REGISTER sip:example.com SIP/2.0\r\n"),
|
||||
expected: "", // Contains a dot, filtered out
|
||||
},
|
||||
{
|
||||
name: "Too long identifier",
|
||||
data: []byte("INVITE sip:verylongusername@example.com SIP/2.0\r\n"),
|
||||
expected: "", // >10 chars
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ext := ExtractTargetExtension(tt.data)
|
||||
if ext != tt.expected {
|
||||
t.Errorf("ExtractTargetExtension: got %q, want %q", ext, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseUserAgent(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Yealink phone",
|
||||
data: []byte("REGISTER sip:example.com SIP/2.0\r\n" +
|
||||
"User-Agent: Yealink SIP-T46S 66.86.0.15\r\n"),
|
||||
expected: "Yealink SIP-T46S 66.86.0.15",
|
||||
},
|
||||
{
|
||||
name: "Linphone",
|
||||
data: []byte("INVITE sip:alice@example.com SIP/2.0\r\n" +
|
||||
"User-Agent: Linphone/4.5.0\r\n"),
|
||||
expected: "Linphone/4.5.0",
|
||||
},
|
||||
{
|
||||
name: "No User-Agent",
|
||||
data: []byte("OPTIONS sip:example.com SIP/2.0\r\n" +
|
||||
"Via: SIP/2.0/UDP 192.168.1.1\r\n"),
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ua := ParseUserAgent(tt.data)
|
||||
if ua != tt.expected {
|
||||
t.Errorf("ParseUserAgent: got %q, want %q", ua, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFromHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
expectedUser string
|
||||
expectedDomain string
|
||||
}{
|
||||
{
|
||||
name: "Standard From header",
|
||||
data: []byte("INVITE sip:bob@example.com SIP/2.0\r\n" +
|
||||
"From: \"Alice\" <sip:alice@sip.example.com>;tag=1234\r\n"),
|
||||
expectedUser: "alice",
|
||||
expectedDomain: "sip.example.com",
|
||||
},
|
||||
{
|
||||
name: "From header without display name",
|
||||
data: []byte("REGISTER sip:example.com SIP/2.0\r\n" +
|
||||
"From: <sip:1001@pbx.local>;tag=abc\r\n"),
|
||||
expectedUser: "1001",
|
||||
expectedDomain: "pbx.local",
|
||||
},
|
||||
{
|
||||
name: "No From header",
|
||||
data: []byte("OPTIONS sip:example.com SIP/2.0\r\n" +
|
||||
"To: <sip:bob@example.com>\r\n"),
|
||||
expectedUser: "",
|
||||
expectedDomain: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
user, domain := ParseFromHeader(tt.data)
|
||||
if user != tt.expectedUser {
|
||||
t.Errorf("ParseFromHeader user: got %q, want %q", user, tt.expectedUser)
|
||||
}
|
||||
if domain != tt.expectedDomain {
|
||||
t.Errorf("ParseFromHeader domain: got %q, want %q", domain, tt.expectedDomain)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseToHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
expectedUser string
|
||||
expectedDomain string
|
||||
}{
|
||||
{
|
||||
name: "Standard To header",
|
||||
data: []byte("INVITE sip:bob@example.com SIP/2.0\r\n" +
|
||||
"To: \"Bob\" <sip:bob@sip.example.com>\r\n"),
|
||||
expectedUser: "bob",
|
||||
expectedDomain: "sip.example.com",
|
||||
},
|
||||
{
|
||||
name: "To header without display name",
|
||||
data: []byte("REGISTER sip:example.com SIP/2.0\r\n" +
|
||||
"To: <sip:5002@pbx.local>\r\n"),
|
||||
expectedUser: "5002",
|
||||
expectedDomain: "pbx.local",
|
||||
},
|
||||
{
|
||||
name: "No To header",
|
||||
data: []byte("OPTIONS sip:example.com SIP/2.0\r\n" +
|
||||
"From: <sip:alice@example.com>\r\n"),
|
||||
expectedUser: "",
|
||||
expectedDomain: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
user, domain := ParseToHeader(tt.data)
|
||||
if user != tt.expectedUser {
|
||||
t.Errorf("ParseToHeader user: got %q, want %q", user, tt.expectedUser)
|
||||
}
|
||||
if domain != tt.expectedDomain {
|
||||
t.Errorf("ParseToHeader domain: got %q, want %q", domain, tt.expectedDomain)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCallID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Standard Call-ID header",
|
||||
data: []byte("INVITE sip:bob@example.com SIP/2.0\r\n" +
|
||||
"Call-ID: 12345-67890@192.168.1.100\r\n"),
|
||||
expected: "12345-67890@192.168.1.100",
|
||||
},
|
||||
{
|
||||
name: "Short form Call-ID (i:)",
|
||||
data: []byte("REGISTER sip:example.com SIP/2.0\r\n" +
|
||||
"i: compact-callid@host\r\n"),
|
||||
expected: "compact-callid@host",
|
||||
},
|
||||
{
|
||||
name: "No Call-ID",
|
||||
data: []byte("OPTIONS sip:example.com SIP/2.0\r\n" +
|
||||
"From: <sip:alice@example.com>\r\n"),
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
callID := ParseCallID(tt.data)
|
||||
if callID != tt.expected {
|
||||
t.Errorf("ParseCallID: got %q, want %q", callID, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Real-world Legitimate Traffic Scenarios
|
||||
// =============================================================================
|
||||
|
||||
func TestLegitimatePhoneRegistration(t *testing.T) {
|
||||
// Simulate a typical phone registration pattern:
|
||||
// 1. Initial REGISTER (usually 401/407 challenge)
|
||||
// 2. Second REGISTER with auth
|
||||
// 3. Keep-alive OPTIONS
|
||||
// 4. Re-REGISTER before expiry
|
||||
|
||||
logger := zap.NewNop()
|
||||
|
||||
// Reset global rate limiter for clean test
|
||||
rateLimiterMu.Lock()
|
||||
globalRateLimiter = nil
|
||||
rateLimiterMu.Unlock()
|
||||
|
||||
rl := GetRateLimiter(logger)
|
||||
ip := "192.168.1.200"
|
||||
|
||||
// Simulate pattern over time
|
||||
allowedCount := 0
|
||||
|
||||
// Initial REGISTER
|
||||
if allowed, _ := rl.Allow(ip, MethodREGISTER); allowed {
|
||||
allowedCount++
|
||||
}
|
||||
// Auth REGISTER
|
||||
if allowed, _ := rl.Allow(ip, MethodREGISTER); allowed {
|
||||
allowedCount++
|
||||
}
|
||||
// Keep-alive OPTIONS (phones do this frequently)
|
||||
for i := 0; i < 5; i++ {
|
||||
if allowed, _ := rl.Allow(ip, MethodOPTIONS); allowed {
|
||||
allowedCount++
|
||||
}
|
||||
}
|
||||
// Another REGISTER
|
||||
if allowed, _ := rl.Allow(ip, MethodREGISTER); allowed {
|
||||
allowedCount++
|
||||
}
|
||||
|
||||
// At minimum, the phone should complete registration (2 REGISTER + OPTIONS)
|
||||
if allowedCount < 5 {
|
||||
t.Errorf("Legitimate phone registration pattern blocked too early, only %d requests allowed", allowedCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLegitimateCallFlow(t *testing.T) {
|
||||
// Simulate a typical call flow:
|
||||
// INVITE -> 100 Trying -> 180 Ringing -> 200 OK -> ACK -> (call) -> BYE -> 200 OK
|
||||
|
||||
logger := zap.NewNop()
|
||||
|
||||
// Reset global rate limiter
|
||||
rateLimiterMu.Lock()
|
||||
globalRateLimiter = nil
|
||||
rateLimiterMu.Unlock()
|
||||
|
||||
rl := GetRateLimiter(logger)
|
||||
ip := "192.168.1.201"
|
||||
|
||||
// INVITE
|
||||
allowed, reason := rl.Allow(ip, MethodINVITE)
|
||||
if !allowed {
|
||||
t.Errorf("INVITE should be allowed for call: %s", reason)
|
||||
}
|
||||
|
||||
// ACK (not rate limited by default)
|
||||
allowed, reason = rl.Allow(ip, MethodACK)
|
||||
if !allowed {
|
||||
t.Errorf("ACK should be allowed for call: %s", reason)
|
||||
}
|
||||
|
||||
// BYE
|
||||
allowed, reason = rl.Allow(ip, MethodBYE)
|
||||
if !allowed {
|
||||
t.Errorf("BYE should be allowed for call: %s", reason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLegitimateSubscriptionFlow(t *testing.T) {
|
||||
// Simulate presence/BLF subscription
|
||||
logger := zap.NewNop()
|
||||
|
||||
// Reset global rate limiter
|
||||
rateLimiterMu.Lock()
|
||||
globalRateLimiter = nil
|
||||
rateLimiterMu.Unlock()
|
||||
|
||||
rl := GetRateLimiter(logger)
|
||||
ip := "192.168.1.202"
|
||||
|
||||
// SUBSCRIBE for presence (phones do multiple for BLF)
|
||||
allowedCount := 0
|
||||
for i := 0; i < 10; i++ {
|
||||
if allowed, _ := rl.Allow(ip, MethodSUBSCRIBE); allowed {
|
||||
allowedCount++
|
||||
}
|
||||
}
|
||||
|
||||
// Default allows 5 burst + some refill
|
||||
if allowedCount < 5 {
|
||||
t.Errorf("Legitimate SUBSCRIBE pattern blocked too early, only %d allowed", allowedCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLegitimateMessaging(t *testing.T) {
|
||||
// Simulate SMS/messaging traffic
|
||||
logger := zap.NewNop()
|
||||
|
||||
// Reset global rate limiter
|
||||
rateLimiterMu.Lock()
|
||||
globalRateLimiter = nil
|
||||
rateLimiterMu.Unlock()
|
||||
|
||||
rl := GetRateLimiter(logger)
|
||||
ip := "192.168.1.203"
|
||||
|
||||
// MESSAGE (higher limit - 100/min default)
|
||||
allowedCount := 0
|
||||
for i := 0; i < 25; i++ {
|
||||
if allowed, _ := rl.Allow(ip, MethodMESSAGE); allowed {
|
||||
allowedCount++
|
||||
}
|
||||
}
|
||||
|
||||
// Should allow most messages (20 burst default)
|
||||
if allowedCount < 20 {
|
||||
t.Errorf("Legitimate MESSAGE pattern blocked too early, only %d allowed", allowedCount)
|
||||
}
|
||||
}
|
||||
110
sipguardian.go
110
sipguardian.go
@ -41,6 +41,11 @@ type SIPGuardian struct {
|
||||
BanTime caddy.Duration `json:"ban_time,omitempty"`
|
||||
WhitelistCIDR []string `json:"whitelist_cidr,omitempty"`
|
||||
|
||||
// DNS-aware whitelist configuration
|
||||
WhitelistHosts []string `json:"whitelist_hosts,omitempty"` // Hostnames to resolve (A/AAAA)
|
||||
WhitelistSRV []string `json:"whitelist_srv,omitempty"` // SRV records to resolve
|
||||
DNSRefresh caddy.Duration `json:"dns_refresh,omitempty"` // DNS refresh interval (default: 5m)
|
||||
|
||||
// Webhook configuration
|
||||
Webhooks []WebhookConfig `json:"webhooks,omitempty"`
|
||||
|
||||
@ -63,6 +68,7 @@ type SIPGuardian struct {
|
||||
bannedIPs map[string]*BanEntry
|
||||
failureCounts map[string]*failureTracker
|
||||
whitelistNets []*net.IPNet
|
||||
dnsWhitelist *DNSWhitelist
|
||||
mu sync.RWMutex
|
||||
storage *Storage
|
||||
geoIP *GeoIPLookup
|
||||
@ -155,6 +161,34 @@ func (g *SIPGuardian) Provision(ctx caddy.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize DNS whitelist if configured
|
||||
if len(g.WhitelistHosts) > 0 || len(g.WhitelistSRV) > 0 {
|
||||
refreshInterval := 5 * time.Minute
|
||||
if g.DNSRefresh > 0 {
|
||||
refreshInterval = time.Duration(g.DNSRefresh)
|
||||
}
|
||||
|
||||
g.dnsWhitelist = NewDNSWhitelist(DNSWhitelistConfig{
|
||||
Hostnames: g.WhitelistHosts,
|
||||
SRVRecords: g.WhitelistSRV,
|
||||
RefreshInterval: refreshInterval,
|
||||
AllowStale: true,
|
||||
ResolveTimeout: 10 * time.Second,
|
||||
}, g.logger)
|
||||
|
||||
if err := g.dnsWhitelist.Start(); err != nil {
|
||||
g.logger.Warn("Failed to initialize DNS whitelist",
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
g.logger.Info("DNS whitelist initialized",
|
||||
zap.Int("hostnames", len(g.WhitelistHosts)),
|
||||
zap.Int("srv_records", len(g.WhitelistSRV)),
|
||||
zap.Duration("refresh_interval", refreshInterval),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize enumeration detection with config if specified
|
||||
if g.Enumeration != nil {
|
||||
SetEnumerationConfig(*g.Enumeration)
|
||||
@ -216,12 +250,14 @@ func (g *SIPGuardian) loadBansFromStorage() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsWhitelisted checks if an IP is in the whitelist
|
||||
// IsWhitelisted checks if an IP is in the whitelist (CIDR or DNS-based)
|
||||
func (g *SIPGuardian) IsWhitelisted(ip string) bool {
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check CIDR-based whitelist
|
||||
for _, network := range g.whitelistNets {
|
||||
if network.Contains(parsedIP) {
|
||||
if enableMetrics {
|
||||
@ -230,6 +266,19 @@ func (g *SIPGuardian) IsWhitelisted(ip string) bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check DNS-based whitelist
|
||||
if g.dnsWhitelist != nil && g.dnsWhitelist.Contains(ip) {
|
||||
if enableMetrics {
|
||||
RecordWhitelistedConnection()
|
||||
}
|
||||
g.logger.Debug("IP whitelisted via DNS",
|
||||
zap.String("ip", ip),
|
||||
zap.String("source", g.dnsWhitelist.GetSource(ip).Source),
|
||||
)
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@ -448,10 +497,34 @@ func (g *SIPGuardian) GetStats() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
stats := map[string]interface{}{
|
||||
"active_bans": activeBans,
|
||||
"tracked_failures": len(g.failureCounts),
|
||||
"whitelist_count": len(g.whitelistNets),
|
||||
"whitelist_cidr": len(g.whitelistNets),
|
||||
"whitelist_hosts": len(g.WhitelistHosts),
|
||||
"whitelist_srv": len(g.WhitelistSRV),
|
||||
}
|
||||
|
||||
// Add DNS whitelist stats if available
|
||||
if g.dnsWhitelist != nil {
|
||||
stats["dns_whitelist"] = g.dnsWhitelist.Stats()
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// GetDNSWhitelistEntries returns all resolved DNS whitelist entries
|
||||
func (g *SIPGuardian) GetDNSWhitelistEntries() []ResolvedEntry {
|
||||
if g.dnsWhitelist == nil {
|
||||
return nil
|
||||
}
|
||||
return g.dnsWhitelist.GetResolvedIPs()
|
||||
}
|
||||
|
||||
// RefreshDNSWhitelist forces an immediate refresh of DNS whitelist entries
|
||||
func (g *SIPGuardian) RefreshDNSWhitelist() {
|
||||
if g.dnsWhitelist != nil {
|
||||
g.dnsWhitelist.ForceRefresh()
|
||||
}
|
||||
}
|
||||
|
||||
@ -500,8 +573,15 @@ func (g *SIPGuardian) cleanup() {
|
||||
// max_failures 5
|
||||
// find_time 10m
|
||||
// ban_time 1h
|
||||
//
|
||||
// # IP/CIDR whitelist (static)
|
||||
// whitelist 10.0.0.0/8 192.168.0.0/16
|
||||
//
|
||||
// # DNS-aware whitelist (dynamic, auto-refreshed)
|
||||
// whitelist_hosts pbx.example.com trunk.provider.com
|
||||
// whitelist_srv _sip._udp.provider.com _sip._tcp.carrier.net
|
||||
// dns_refresh 5m # How often to refresh DNS (default: 5m)
|
||||
//
|
||||
// # Persistent storage
|
||||
// storage /var/lib/sip-guardian/guardian.db
|
||||
//
|
||||
@ -552,10 +632,34 @@ func (g *SIPGuardian) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
||||
g.BanTime = caddy.Duration(dur)
|
||||
|
||||
case "whitelist":
|
||||
// Legacy: CIDR-only whitelist
|
||||
for d.NextArg() {
|
||||
g.WhitelistCIDR = append(g.WhitelistCIDR, d.Val())
|
||||
}
|
||||
|
||||
case "whitelist_hosts":
|
||||
// DNS A/AAAA record whitelist (hostnames resolved to IPs)
|
||||
for d.NextArg() {
|
||||
g.WhitelistHosts = append(g.WhitelistHosts, d.Val())
|
||||
}
|
||||
|
||||
case "whitelist_srv":
|
||||
// DNS SRV record whitelist (e.g., _sip._udp.provider.com)
|
||||
for d.NextArg() {
|
||||
g.WhitelistSRV = append(g.WhitelistSRV, d.Val())
|
||||
}
|
||||
|
||||
case "dns_refresh":
|
||||
// Interval for refreshing DNS-based whitelist entries
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
dur, err := caddy.ParseDuration(d.Val())
|
||||
if err != nil {
|
||||
return d.Errf("invalid dns_refresh: %v", err)
|
||||
}
|
||||
g.DNSRefresh = caddy.Duration(dur)
|
||||
|
||||
case "storage":
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user