caddy-sip-guardian/ratelimit_test.go
Ryan Malloy 5cf34eb3c0 Add DNS-aware whitelisting feature
Support for whitelisting SIP trunks and providers by hostname or SRV
record with automatic IP resolution and periodic refresh.

Features:
- Hostname resolution via A/AAAA records
- SRV record resolution (e.g., _sip._udp.provider.com)
- Configurable refresh interval (default 5m)
- Stale entry handling when DNS fails
- Admin API endpoints for DNS whitelist management
- Caddyfile directives: whitelist_hosts, whitelist_srv, dns_refresh

This allows whitelisting by provider name rather than tracking
constantly-changing IP addresses.
2025-12-08 00:46:43 -07:00

699 lines
18 KiB
Go

package sipguardian
import (
"sync"
"testing"
"time"
"go.uber.org/zap"
)
// =============================================================================
// Rate Limiter Tests - Ensuring legitimate traffic flows through
// =============================================================================
func TestRateLimiterBasicAllow(t *testing.T) {
logger := zap.NewNop()
rl := NewRateLimiter(logger)
// Configure a simple limit
rl.SetLimit(MethodREGISTER, &MethodRateLimit{
Method: MethodREGISTER,
MaxRequests: 10,
Window: time.Minute,
BurstSize: 5,
})
ip := "192.168.1.100"
// First 5 requests should be allowed (burst)
for i := 0; i < 5; i++ {
allowed, reason := rl.Allow(ip, MethodREGISTER)
if !allowed {
t.Errorf("Request %d should be allowed within burst, reason: %s", i+1, reason)
}
}
}
func TestRateLimiterBurstExhaustion(t *testing.T) {
logger := zap.NewNop()
rl := NewRateLimiter(logger)
rl.SetLimit(MethodREGISTER, &MethodRateLimit{
Method: MethodREGISTER,
MaxRequests: 10,
Window: time.Minute,
BurstSize: 3,
})
ip := "192.168.1.101"
// Exhaust burst
for i := 0; i < 3; i++ {
allowed, _ := rl.Allow(ip, MethodREGISTER)
if !allowed {
t.Errorf("Request %d should be allowed within burst", i+1)
}
}
// Next request should be rate limited (burst exhausted)
allowed, reason := rl.Allow(ip, MethodREGISTER)
if allowed {
t.Error("Request after burst should be rate limited")
}
if reason != "rate_limit_REGISTER" {
t.Errorf("Reason should be rate_limit_REGISTER, got: %s", reason)
}
}
func TestRateLimiterTokenRefill(t *testing.T) {
logger := zap.NewNop()
rl := NewRateLimiter(logger)
// 60 requests per minute = 1 per second
rl.SetLimit(MethodOPTIONS, &MethodRateLimit{
Method: MethodOPTIONS,
MaxRequests: 60,
Window: time.Minute,
BurstSize: 2,
})
ip := "192.168.1.102"
// Exhaust burst
rl.Allow(ip, MethodOPTIONS)
rl.Allow(ip, MethodOPTIONS)
// Should be blocked now
allowed, _ := rl.Allow(ip, MethodOPTIONS)
if allowed {
t.Error("Should be rate limited after burst")
}
// Wait for token refill (slightly more than 1 second for 1 token)
time.Sleep(1100 * time.Millisecond)
// Should be allowed again
allowed, reason := rl.Allow(ip, MethodOPTIONS)
if !allowed {
t.Errorf("Should be allowed after token refill, reason: %s", reason)
}
}
func TestRateLimiterDifferentMethods(t *testing.T) {
logger := zap.NewNop()
rl := NewRateLimiter(logger)
// Different limits for different methods
rl.SetLimit(MethodREGISTER, &MethodRateLimit{
Method: MethodREGISTER,
MaxRequests: 10,
Window: time.Minute,
BurstSize: 2,
})
rl.SetLimit(MethodINVITE, &MethodRateLimit{
Method: MethodINVITE,
MaxRequests: 30,
Window: time.Minute,
BurstSize: 5,
})
ip := "192.168.1.103"
// Exhaust REGISTER burst
rl.Allow(ip, MethodREGISTER)
rl.Allow(ip, MethodREGISTER)
// REGISTER should be blocked
allowed, _ := rl.Allow(ip, MethodREGISTER)
if allowed {
t.Error("REGISTER should be rate limited")
}
// But INVITE should still work (separate bucket)
for i := 0; i < 5; i++ {
allowed, reason := rl.Allow(ip, MethodINVITE)
if !allowed {
t.Errorf("INVITE %d should be allowed, reason: %s", i+1, reason)
}
}
}
func TestRateLimiterDifferentIPs(t *testing.T) {
logger := zap.NewNop()
rl := NewRateLimiter(logger)
rl.SetLimit(MethodREGISTER, &MethodRateLimit{
Method: MethodREGISTER,
MaxRequests: 10,
Window: time.Minute,
BurstSize: 2,
})
ip1 := "192.168.1.104"
ip2 := "192.168.1.105"
// Exhaust IP1's burst
rl.Allow(ip1, MethodREGISTER)
rl.Allow(ip1, MethodREGISTER)
allowed, _ := rl.Allow(ip1, MethodREGISTER)
if allowed {
t.Error("IP1 should be rate limited")
}
// IP2 should still work (separate bucket)
for i := 0; i < 2; i++ {
allowed, reason := rl.Allow(ip2, MethodREGISTER)
if !allowed {
t.Errorf("IP2 request %d should be allowed, reason: %s", i+1, reason)
}
}
}
// TestRateLimiterDefaultLimits checks that the global limiter returned by
// GetRateLimiter reports, for every method in DefaultMethodLimits, the
// MaxRequests value configured there.
func TestRateLimiterDefaultLimits(t *testing.T) {
	logger := zap.NewNop()

	// Drop any cached singleton so GetRateLimiter has to build, and configure,
	// a brand-new instance.
	rateLimiterMu.Lock()
	globalRateLimiter = nil
	rateLimiterMu.Unlock()

	limiter := GetRateLimiter(logger)

	for method, want := range DefaultMethodLimits {
		got := limiter.GetLimit(method)
		if got.MaxRequests != want.MaxRequests {
			t.Errorf("Default limit for %s: got %d, want %d", method, got.MaxRequests, want.MaxRequests)
		}
	}
}
func TestNewRateLimiterHasNoDefaultLimits(t *testing.T) {
// NewRateLimiter creates a fresh limiter without default method limits
// This is intentional for testing and custom configurations
logger := zap.NewNop()
rl := NewRateLimiter(logger)
// Should return the global default (100), not method-specific defaults
limit := rl.GetLimit(MethodREGISTER)
if limit.MaxRequests != 100 {
t.Errorf("NewRateLimiter should have global default 100 for unconfigured methods, got %d", limit.MaxRequests)
}
}
// TestRateLimiterConcurrentAccess fires 100 REGISTER requests from concurrent
// goroutines, spread evenly over 10 source IPs (10 requests each).
//
// Each IP has a burst of 50, so the scenario never comes close to exhausting
// any bucket and every request must be allowed. Any rejection points to a
// concurrency bug in bucket creation or token accounting. Run with -race to
// also catch unsynchronised access inside the limiter.
func TestRateLimiterConcurrentAccess(t *testing.T) {
	logger := zap.NewNop()
	rl := NewRateLimiter(logger)
	rl.SetLimit(MethodREGISTER, &MethodRateLimit{
		Method:      MethodREGISTER,
		MaxRequests: 100,
		Window:      time.Minute,
		BurstSize:   50,
	})

	const (
		numRequests = 100
		numIPs      = 10 // must stay <= 10: the IP suffix below is a single digit
	)

	var (
		wg           sync.WaitGroup
		mu           sync.Mutex
		allowedCount int
	)
	for i := 0; i < numRequests; i++ {
		wg.Add(1)
		go func(n int) {
			defer wg.Done()
			// n%numIPs is 0..9, giving the addresses 192.168.1.0 through 192.168.1.9.
			ip := "192.168.1." + string(rune('0'+n%numIPs))
			if allowed, _ := rl.Allow(ip, MethodREGISTER); allowed {
				mu.Lock()
				allowedCount++
				mu.Unlock()
			}
		}(i)
	}
	// wg.Wait happens-before the read below, so allowedCount needs no lock here.
	wg.Wait()

	// 10 requests per IP against a burst of 50: nothing should be throttled.
	// The previous ">= 90" check would have hidden up to 10 wrongful rejections.
	if allowedCount != numRequests {
		t.Errorf("Expected all %d concurrent requests to be allowed, got %d", numRequests, allowedCount)
	}
}
// TestRateLimiterCleanup creates buckets for three IPs and checks that GetStats
// reports them, and that Cleanup does not evict entries that were just used.
//
// The "tracked_ips" stat is read with a checked type assertion. A missing key
// or an unexpected type then fails the test with a clear message instead of
// panicking.
func TestRateLimiterCleanup(t *testing.T) {
	logger := zap.NewNop()
	rl := NewRateLimiter(logger)

	// trackedIPs returns stats["tracked_ips"] as an int, failing the test if absent or mistyped.
	trackedIPs := func() int {
		t.Helper()
		raw := rl.GetStats()["tracked_ips"]
		n, ok := raw.(int)
		if !ok {
			t.Fatalf("stats[%q] should be an int, got %T (%v)", "tracked_ips", raw, raw)
		}
		return n
	}

	// One request per IP is enough to create its bucket.
	rl.Allow("192.168.1.1", MethodREGISTER)
	rl.Allow("192.168.1.2", MethodREGISTER)
	rl.Allow("192.168.1.3", MethodREGISTER)

	if got := trackedIPs(); got != 3 {
		t.Errorf("Should have 3 tracked IPs, got %d", got)
	}

	// Cleanup shouldn't remove fresh entries.
	rl.Cleanup()

	if got := trackedIPs(); got != 3 {
		t.Errorf("Fresh entries should not be cleaned up, got %d IPs", got)
	}
}
// TestRateLimiterStats checks the shape of GetStats output. It must include
// "tracked_ips" and a "limits" map that contains the configured REGISTER entry.
//
// A missing or mistyped "limits" value now stops the test with t.Fatal. The old
// version reported it with t.Error and then panicked on an unchecked type
// assertion.
func TestRateLimiterStats(t *testing.T) {
	logger := zap.NewNop()
	rl := NewRateLimiter(logger)
	rl.SetLimit(MethodREGISTER, &MethodRateLimit{
		Method:      MethodREGISTER,
		MaxRequests: 10,
		Window:      time.Minute,
		BurstSize:   5,
	})
	rl.Allow("192.168.1.100", MethodREGISTER)

	stats := rl.GetStats()
	if _, ok := stats["tracked_ips"]; !ok {
		t.Error("Stats should include tracked_ips")
	}

	rawLimits, ok := stats["limits"]
	if !ok {
		// Later checks depend on this value, so stop here.
		t.Fatal("Stats should include limits")
	}
	limits, ok := rawLimits.(map[string]interface{})
	if !ok {
		t.Fatalf("Stats limits should be a map[string]interface{}, got %T", rawLimits)
	}
	if _, ok := limits["REGISTER"]; !ok {
		t.Error("Limits should include REGISTER config")
	}
}
// TestGlobalRateLimiter covers the singleton accessor. It checks three things:
// GetRateLimiter returns a non-nil instance, repeated calls hand back that same
// instance, and the instance has the REGISTER default (MaxRequests == 10).
func TestGlobalRateLimiter(t *testing.T) {
	logger := zap.NewNop()

	// Clear the singleton so this test sees a freshly built instance.
	rateLimiterMu.Lock()
	globalRateLimiter = nil
	rateLimiterMu.Unlock()

	first := GetRateLimiter(logger)
	if first == nil {
		t.Fatal("GetRateLimiter should return non-nil")
	}

	if second := GetRateLimiter(logger); second != first {
		t.Error("GetRateLimiter should return same global instance")
	}

	if got := first.GetLimit(MethodREGISTER).MaxRequests; got != 10 {
		t.Errorf("Global rate limiter should have default REGISTER limit, got %d", got)
	}
}
// =============================================================================
// SIP Parsing Function Tests
// =============================================================================
// TestExtractSIPMethod checks that every supported request method is recognised
// from its request line. Responses and non-SIP traffic must yield the empty
// method.
func TestExtractSIPMethod(t *testing.T) {
	cases := []struct {
		name string
		msg  string
		want SIPMethod
	}{
		{name: "REGISTER", msg: "REGISTER sip:example.com SIP/2.0\r\n", want: MethodREGISTER},
		{name: "INVITE", msg: "INVITE sip:alice@example.com SIP/2.0\r\n", want: MethodINVITE},
		{name: "OPTIONS", msg: "OPTIONS sip:example.com SIP/2.0\r\n", want: MethodOPTIONS},
		{name: "ACK", msg: "ACK sip:alice@example.com SIP/2.0\r\n", want: MethodACK},
		{name: "BYE", msg: "BYE sip:alice@example.com SIP/2.0\r\n", want: MethodBYE},
		{name: "CANCEL", msg: "CANCEL sip:alice@example.com SIP/2.0\r\n", want: MethodCANCEL},
		{name: "INFO", msg: "INFO sip:alice@example.com SIP/2.0\r\n", want: MethodINFO},
		{name: "NOTIFY", msg: "NOTIFY sip:alice@example.com SIP/2.0\r\n", want: MethodNOTIFY},
		{name: "SUBSCRIBE", msg: "SUBSCRIBE sip:alice@example.com SIP/2.0\r\n", want: MethodSUBSCRIBE},
		{name: "MESSAGE", msg: "MESSAGE sip:alice@example.com SIP/2.0\r\n", want: MethodMESSAGE},
		{name: "UPDATE", msg: "UPDATE sip:alice@example.com SIP/2.0\r\n", want: MethodUPDATE},
		{name: "PRACK", msg: "PRACK sip:alice@example.com SIP/2.0\r\n", want: MethodPRACK},
		{name: "REFER", msg: "REFER sip:alice@example.com SIP/2.0\r\n", want: MethodREFER},
		{name: "PUBLISH", msg: "PUBLISH sip:alice@example.com SIP/2.0\r\n", want: MethodPUBLISH},
		{name: "Response (no method)", msg: "SIP/2.0 200 OK\r\n", want: ""},
		{name: "Non-SIP", msg: "GET / HTTP/1.1\r\n", want: ""},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := ExtractSIPMethod([]byte(tc.msg)); got != tc.want {
				t.Errorf("ExtractSIPMethod: got %q, want %q", got, tc.want)
			}
		})
	}
}
// TestExtractTargetExtensionEdgeCases adds rate-limiter-related edge cases on
// top of the basic TestExtractTargetExtension, which lives in
// enumeration_test.go.
//
// Short numeric and alphanumeric user parts are expected to be extracted. A
// request URI with no user part (a bare domain containing a dot) and a user part
// longer than 10 characters are expected to yield "".
func TestExtractTargetExtensionEdgeCases(t *testing.T) {
	cases := []struct {
		name string
		msg  string
		want string
	}{
		// User parts that should be returned as extensions.
		{name: "Simple 4-digit extension", msg: "REGISTER sip:1001@example.com SIP/2.0\r\n", want: "1001"},
		{name: "3-digit extension", msg: "INVITE sip:200@pbx.local SIP/2.0\r\n", want: "200"},
		{name: "Alphanumeric short name", msg: "INVITE sip:john@example.com SIP/2.0\r\n", want: "john"},

		// Inputs that must not produce an extension.
		{name: "Full domain should not be extracted", msg: "REGISTER sip:example.com SIP/2.0\r\n", want: ""},
		{name: "Too long identifier", msg: "INVITE sip:verylongusername@example.com SIP/2.0\r\n", want: ""},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := ExtractTargetExtension([]byte(tc.msg)); got != tc.want {
				t.Errorf("ExtractTargetExtension: got %q, want %q", got, tc.want)
			}
		})
	}
}
// TestParseUserAgent checks that the User-Agent header value is returned
// verbatim, and that a message without the header yields "".
func TestParseUserAgent(t *testing.T) {
	cases := []struct {
		name string
		msg  string
		want string
	}{
		{
			name: "Yealink phone",
			msg:  "REGISTER sip:example.com SIP/2.0\r\nUser-Agent: Yealink SIP-T46S 66.86.0.15\r\n",
			want: "Yealink SIP-T46S 66.86.0.15",
		},
		{
			name: "Linphone",
			msg:  "INVITE sip:alice@example.com SIP/2.0\r\nUser-Agent: Linphone/4.5.0\r\n",
			want: "Linphone/4.5.0",
		},
		{
			name: "No User-Agent",
			msg:  "OPTIONS sip:example.com SIP/2.0\r\nVia: SIP/2.0/UDP 192.168.1.1\r\n",
			want: "",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := ParseUserAgent([]byte(tc.msg)); got != tc.want {
				t.Errorf("ParseUserAgent: got %q, want %q", got, tc.want)
			}
		})
	}
}
// TestParseFromHeader checks that the user and domain of the From URI are
// extracted, with or without a display name. A message lacking a From header
// must yield empty strings for both.
func TestParseFromHeader(t *testing.T) {
	cases := []struct {
		name       string
		msg        string
		wantUser   string
		wantDomain string
	}{
		{
			name:       "Standard From header",
			msg:        "INVITE sip:bob@example.com SIP/2.0\r\nFrom: \"Alice\" <sip:alice@sip.example.com>;tag=1234\r\n",
			wantUser:   "alice",
			wantDomain: "sip.example.com",
		},
		{
			name:       "From header without display name",
			msg:        "REGISTER sip:example.com SIP/2.0\r\nFrom: <sip:1001@pbx.local>;tag=abc\r\n",
			wantUser:   "1001",
			wantDomain: "pbx.local",
		},
		{
			name:       "No From header",
			msg:        "OPTIONS sip:example.com SIP/2.0\r\nTo: <sip:bob@example.com>\r\n",
			wantUser:   "",
			wantDomain: "",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			gotUser, gotDomain := ParseFromHeader([]byte(tc.msg))
			if gotUser != tc.wantUser {
				t.Errorf("ParseFromHeader user: got %q, want %q", gotUser, tc.wantUser)
			}
			if gotDomain != tc.wantDomain {
				t.Errorf("ParseFromHeader domain: got %q, want %q", gotDomain, tc.wantDomain)
			}
		})
	}
}
// TestParseToHeader checks that the user and domain of the To URI are
// extracted, with or without a display name. A message lacking a To header
// must yield empty strings for both.
func TestParseToHeader(t *testing.T) {
	cases := []struct {
		name       string
		msg        string
		wantUser   string
		wantDomain string
	}{
		{
			name:       "Standard To header",
			msg:        "INVITE sip:bob@example.com SIP/2.0\r\nTo: \"Bob\" <sip:bob@sip.example.com>\r\n",
			wantUser:   "bob",
			wantDomain: "sip.example.com",
		},
		{
			name:       "To header without display name",
			msg:        "REGISTER sip:example.com SIP/2.0\r\nTo: <sip:5002@pbx.local>\r\n",
			wantUser:   "5002",
			wantDomain: "pbx.local",
		},
		{
			name:       "No To header",
			msg:        "OPTIONS sip:example.com SIP/2.0\r\nFrom: <sip:alice@example.com>\r\n",
			wantUser:   "",
			wantDomain: "",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			gotUser, gotDomain := ParseToHeader([]byte(tc.msg))
			if gotUser != tc.wantUser {
				t.Errorf("ParseToHeader user: got %q, want %q", gotUser, tc.wantUser)
			}
			if gotDomain != tc.wantDomain {
				t.Errorf("ParseToHeader domain: got %q, want %q", gotDomain, tc.wantDomain)
			}
		})
	}
}
// TestParseCallID checks Call-ID extraction from both the long header name and
// the compact "i:" form. A message without a Call-ID must yield "".
func TestParseCallID(t *testing.T) {
	cases := []struct {
		name string
		msg  string
		want string
	}{
		{
			name: "Standard Call-ID header",
			msg:  "INVITE sip:bob@example.com SIP/2.0\r\nCall-ID: 12345-67890@192.168.1.100\r\n",
			want: "12345-67890@192.168.1.100",
		},
		{
			name: "Short form Call-ID (i:)",
			msg:  "REGISTER sip:example.com SIP/2.0\r\ni: compact-callid@host\r\n",
			want: "compact-callid@host",
		},
		{
			name: "No Call-ID",
			msg:  "OPTIONS sip:example.com SIP/2.0\r\nFrom: <sip:alice@example.com>\r\n",
			want: "",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := ParseCallID([]byte(tc.msg)); got != tc.want {
				t.Errorf("ParseCallID: got %q, want %q", got, tc.want)
			}
		})
	}
}
// =============================================================================
// Real-world Legitimate Traffic Scenarios
// =============================================================================
func TestLegitimatePhoneRegistration(t *testing.T) {
// Simulate a typical phone registration pattern:
// 1. Initial REGISTER (usually 401/407 challenge)
// 2. Second REGISTER with auth
// 3. Keep-alive OPTIONS
// 4. Re-REGISTER before expiry
logger := zap.NewNop()
// Reset global rate limiter for clean test
rateLimiterMu.Lock()
globalRateLimiter = nil
rateLimiterMu.Unlock()
rl := GetRateLimiter(logger)
ip := "192.168.1.200"
// Simulate pattern over time
allowedCount := 0
// Initial REGISTER
if allowed, _ := rl.Allow(ip, MethodREGISTER); allowed {
allowedCount++
}
// Auth REGISTER
if allowed, _ := rl.Allow(ip, MethodREGISTER); allowed {
allowedCount++
}
// Keep-alive OPTIONS (phones do this frequently)
for i := 0; i < 5; i++ {
if allowed, _ := rl.Allow(ip, MethodOPTIONS); allowed {
allowedCount++
}
}
// Another REGISTER
if allowed, _ := rl.Allow(ip, MethodREGISTER); allowed {
allowedCount++
}
// At minimum, the phone should complete registration (2 REGISTER + OPTIONS)
if allowedCount < 5 {
t.Errorf("Legitimate phone registration pattern blocked too early, only %d requests allowed", allowedCount)
}
}
func TestLegitimateCallFlow(t *testing.T) {
// Simulate a typical call flow:
// INVITE -> 100 Trying -> 180 Ringing -> 200 OK -> ACK -> (call) -> BYE -> 200 OK
logger := zap.NewNop()
// Reset global rate limiter
rateLimiterMu.Lock()
globalRateLimiter = nil
rateLimiterMu.Unlock()
rl := GetRateLimiter(logger)
ip := "192.168.1.201"
// INVITE
allowed, reason := rl.Allow(ip, MethodINVITE)
if !allowed {
t.Errorf("INVITE should be allowed for call: %s", reason)
}
// ACK (not rate limited by default)
allowed, reason = rl.Allow(ip, MethodACK)
if !allowed {
t.Errorf("ACK should be allowed for call: %s", reason)
}
// BYE
allowed, reason = rl.Allow(ip, MethodBYE)
if !allowed {
t.Errorf("BYE should be allowed for call: %s", reason)
}
}
func TestLegitimateSubscriptionFlow(t *testing.T) {
// Simulate presence/BLF subscription
logger := zap.NewNop()
// Reset global rate limiter
rateLimiterMu.Lock()
globalRateLimiter = nil
rateLimiterMu.Unlock()
rl := GetRateLimiter(logger)
ip := "192.168.1.202"
// SUBSCRIBE for presence (phones do multiple for BLF)
allowedCount := 0
for i := 0; i < 10; i++ {
if allowed, _ := rl.Allow(ip, MethodSUBSCRIBE); allowed {
allowedCount++
}
}
// Default allows 5 burst + some refill
if allowedCount < 5 {
t.Errorf("Legitimate SUBSCRIBE pattern blocked too early, only %d allowed", allowedCount)
}
}
func TestLegitimateMessaging(t *testing.T) {
// Simulate SMS/messaging traffic
logger := zap.NewNop()
// Reset global rate limiter
rateLimiterMu.Lock()
globalRateLimiter = nil
rateLimiterMu.Unlock()
rl := GetRateLimiter(logger)
ip := "192.168.1.203"
// MESSAGE (higher limit - 100/min default)
allowedCount := 0
for i := 0; i < 25; i++ {
if allowed, _ := rl.Allow(ip, MethodMESSAGE); allowed {
allowedCount++
}
}
// Should allow most messages (20 burst default)
if allowedCount < 20 {
t.Errorf("Legitimate MESSAGE pattern blocked too early, only %d allowed", allowedCount)
}
}