package sipguardian

import (
	"strconv"
	"sync"
	"testing"
	"time"

	"go.uber.org/zap"
)

// =============================================================================
// Rate Limiter Tests - Ensuring legitimate traffic flows through
// =============================================================================

// resetGlobalRateLimiterForTest clears the package-level rate limiter so the
// next GetRateLimiter call builds a fresh instance (with default limits), and
// registers a cleanup that restores the previous instance once the test ends.
// This keeps bucket state consumed by one test from leaking into another.
func resetGlobalRateLimiterForTest(t *testing.T) {
	t.Helper()

	rateLimiterMu.Lock()
	prev := globalRateLimiter
	globalRateLimiter = nil
	rateLimiterMu.Unlock()

	t.Cleanup(func() {
		rateLimiterMu.Lock()
		globalRateLimiter = prev
		rateLimiterMu.Unlock()
	})
}

func TestRateLimiterBasicAllow(t *testing.T) {
	logger := zap.NewNop()
	rl := NewRateLimiter(logger)

	// Configure a simple limit
	rl.SetLimit(MethodREGISTER, &MethodRateLimit{
		Method:      MethodREGISTER,
		MaxRequests: 10,
		Window:      time.Minute,
		BurstSize:   5,
	})

	ip := "192.168.1.100"

	// First 5 requests should be allowed (burst)
	for i := 0; i < 5; i++ {
		allowed, reason := rl.Allow(ip, MethodREGISTER)
		if !allowed {
			t.Errorf("Request %d should be allowed within burst, reason: %s", i+1, reason)
		}
	}
}

func TestRateLimiterBurstExhaustion(t *testing.T) {
	logger := zap.NewNop()
	rl := NewRateLimiter(logger)

	rl.SetLimit(MethodREGISTER, &MethodRateLimit{
		Method:      MethodREGISTER,
		MaxRequests: 10,
		Window:      time.Minute,
		BurstSize:   3,
	})

	ip := "192.168.1.101"

	// Exhaust burst
	for i := 0; i < 3; i++ {
		allowed, _ := rl.Allow(ip, MethodREGISTER)
		if !allowed {
			t.Errorf("Request %d should be allowed within burst", i+1)
		}
	}

	// Next request should be rate limited (burst exhausted)
	allowed, reason := rl.Allow(ip, MethodREGISTER)
	if allowed {
		t.Error("Request after burst should be rate limited")
	}
	if reason != "rate_limit_REGISTER" {
		t.Errorf("Reason should be rate_limit_REGISTER, got: %s", reason)
	}
}

func TestRateLimiterTokenRefill(t *testing.T) {
	logger := zap.NewNop()
	rl := NewRateLimiter(logger)

	// 60 requests per minute = 1 per second
	rl.SetLimit(MethodOPTIONS, &MethodRateLimit{
		Method:      MethodOPTIONS,
		MaxRequests: 60,
		Window:      time.Minute,
		BurstSize:   2,
	})

	ip := "192.168.1.102"

	// Exhaust burst
	rl.Allow(ip, MethodOPTIONS)
	rl.Allow(ip, MethodOPTIONS)

	// Should be blocked now
	allowed, _ := rl.Allow(ip, MethodOPTIONS)
	if allowed {
		t.Error("Should be rate limited after burst")
	}

	// Wait for token refill (slightly more than 1 second for 1 token).
	// NOTE(review): this is a real wall-clock sleep; an injectable clock in the
	// limiter would make this test faster and less timing-sensitive.
	time.Sleep(1100 * time.Millisecond)

	// Should be allowed again
	allowed, reason := rl.Allow(ip, MethodOPTIONS)
	if !allowed {
		t.Errorf("Should be allowed after token refill, reason: %s", reason)
	}
}

func TestRateLimiterDifferentMethods(t *testing.T) {
	logger := zap.NewNop()
	rl := NewRateLimiter(logger)

	// Different limits for different methods
	rl.SetLimit(MethodREGISTER, &MethodRateLimit{
		Method:      MethodREGISTER,
		MaxRequests: 10,
		Window:      time.Minute,
		BurstSize:   2,
	})
	rl.SetLimit(MethodINVITE, &MethodRateLimit{
		Method:      MethodINVITE,
		MaxRequests: 30,
		Window:      time.Minute,
		BurstSize:   5,
	})

	ip := "192.168.1.103"

	// Exhaust REGISTER burst
	rl.Allow(ip, MethodREGISTER)
	rl.Allow(ip, MethodREGISTER)

	// REGISTER should be blocked
	allowed, _ := rl.Allow(ip, MethodREGISTER)
	if allowed {
		t.Error("REGISTER should be rate limited")
	}

	// But INVITE should still work (separate bucket)
	for i := 0; i < 5; i++ {
		allowed, reason := rl.Allow(ip, MethodINVITE)
		if !allowed {
			t.Errorf("INVITE %d should be allowed, reason: %s", i+1, reason)
		}
	}
}

func TestRateLimiterDifferentIPs(t *testing.T) {
	logger := zap.NewNop()
	rl := NewRateLimiter(logger)

	rl.SetLimit(MethodREGISTER, &MethodRateLimit{
		Method:      MethodREGISTER,
		MaxRequests: 10,
		Window:      time.Minute,
		BurstSize:   2,
	})

	ip1 := "192.168.1.104"
	ip2 := "192.168.1.105"

	// Exhaust IP1's burst
	rl.Allow(ip1, MethodREGISTER)
	rl.Allow(ip1, MethodREGISTER)
	allowed, _ := rl.Allow(ip1, MethodREGISTER)
	if allowed {
		t.Error("IP1 should be rate limited")
	}

	// IP2 should still work (separate bucket)
	for i := 0; i < 2; i++ {
		allowed, reason := rl.Allow(ip2, MethodREGISTER)
		if !allowed {
			t.Errorf("IP2 request %d should be allowed, reason: %s", i+1, reason)
		}
	}
}

func TestRateLimiterDefaultLimits(t *testing.T) {
	logger := zap.NewNop()

	// Reset global rate limiter to get a fresh instance with defaults applied
	resetGlobalRateLimiterForTest(t)

	// GetRateLimiter applies default limits from DefaultMethodLimits
	rl := GetRateLimiter(logger)

	// Test that DefaultMethodLimits are applied
	for method, expectedLimit := range DefaultMethodLimits {
		limit := rl.GetLimit(method)
		if limit.MaxRequests != expectedLimit.MaxRequests {
			t.Errorf("Default limit for %s: got %d, want %d", method, limit.MaxRequests, expectedLimit.MaxRequests)
		}
	}
}

func TestNewRateLimiterHasNoDefaultLimits(t *testing.T) {
	// NewRateLimiter creates a fresh limiter without default method limits.
	// This is intentional for testing and custom configurations.
	logger := zap.NewNop()
	rl := NewRateLimiter(logger)

	// Should return the global default (100), not method-specific defaults
	limit := rl.GetLimit(MethodREGISTER)
	if limit.MaxRequests != 100 {
		t.Errorf("NewRateLimiter should have global default 100 for unconfigured methods, got %d", limit.MaxRequests)
	}
}

func TestRateLimiterConcurrentAccess(t *testing.T) {
	logger := zap.NewNop()
	rl := NewRateLimiter(logger)

	rl.SetLimit(MethodREGISTER, &MethodRateLimit{
		Method:      MethodREGISTER,
		MaxRequests: 100,
		Window:      time.Minute,
		BurstSize:   50,
	})

	var wg sync.WaitGroup
	var mu sync.Mutex
	allowedCount := 0

	// Simulate 100 concurrent requests spread over 10 IPs (10 requests each)
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func(n int) {
			defer wg.Done()
			// Produces 192.168.1.0 through 192.168.1.9.
			ip := "192.168.1." + strconv.Itoa(n%10)
			allowed, _ := rl.Allow(ip, MethodREGISTER)
			if allowed {
				mu.Lock()
				allowedCount++
				mu.Unlock()
			}
		}(i)
	}
	wg.Wait()

	// Should have allowed most requests (10 IPs with burst of 50 each = up to 500 capacity)
	if allowedCount < 90 {
		t.Errorf("Expected at least 90 allowed requests, got %d", allowedCount)
	}
}

func TestRateLimiterCleanup(t *testing.T) {
	logger := zap.NewNop()
	rl := NewRateLimiter(logger)

	// Add some buckets
	rl.Allow("192.168.1.1", MethodREGISTER)
	rl.Allow("192.168.1.2", MethodREGISTER)
	rl.Allow("192.168.1.3", MethodREGISTER)

	stats := rl.GetStats()
	// Comma-ok assertion: a missing or mistyped stat fails the test instead of panicking.
	trackedIPs, ok := stats["tracked_ips"].(int)
	if !ok {
		t.Fatalf("tracked_ips should be an int, got %T", stats["tracked_ips"])
	}
	if trackedIPs != 3 {
		t.Errorf("Should have 3 tracked IPs, got %d", trackedIPs)
	}

	// Cleanup shouldn't remove fresh entries
	rl.Cleanup()

	stats = rl.GetStats()
	trackedIPs, ok = stats["tracked_ips"].(int)
	if !ok {
		t.Fatalf("tracked_ips should be an int, got %T", stats["tracked_ips"])
	}
	if trackedIPs != 3 {
		t.Errorf("Fresh entries should not be cleaned up, got %d IPs", trackedIPs)
	}
}

func TestRateLimiterStats(t *testing.T) {
	logger := zap.NewNop()
	rl := NewRateLimiter(logger)

	rl.SetLimit(MethodREGISTER, &MethodRateLimit{
		Method:      MethodREGISTER,
		MaxRequests: 10,
		Window:      time.Minute,
		BurstSize:   5,
	})

	rl.Allow("192.168.1.100", MethodREGISTER)

	stats := rl.GetStats()

	if _, ok := stats["tracked_ips"]; !ok {
		t.Error("Stats should include tracked_ips")
	}
	if _, ok := stats["limits"]; !ok {
		// Fatal: the type assertion below would otherwise panic on a nil value.
		t.Fatal("Stats should include limits")
	}

	limits, ok := stats["limits"].(map[string]interface{})
	if !ok {
		t.Fatalf("limits should be map[string]interface{}, got %T", stats["limits"])
	}
	if _, ok := limits["REGISTER"]; !ok {
		t.Error("Limits should include REGISTER config")
	}
}

func TestGlobalRateLimiter(t *testing.T) {
	logger := zap.NewNop()

	// Reset global rate limiter for clean test
	resetGlobalRateLimiterForTest(t)

	rl := GetRateLimiter(logger)
	if rl == nil {
		t.Fatal("GetRateLimiter should return non-nil")
	}

	// Should return same instance
	rl2 := GetRateLimiter(logger)
	if rl != rl2 {
		t.Error("GetRateLimiter should return same global instance")
	}

	// Should have default limits applied
	registerLimit := rl.GetLimit(MethodREGISTER)
	if registerLimit.MaxRequests != 10 {
		t.Errorf("Global rate limiter should have default REGISTER limit, got %d", registerLimit.MaxRequests)
	}
}

// =============================================================================
// SIP Parsing Function Tests
// =============================================================================

func TestExtractSIPMethod(t *testing.T) {
	tests := []struct {
		name     string
		data     []byte
		expected SIPMethod
	}{
		{"REGISTER", []byte("REGISTER sip:example.com SIP/2.0\r\n"), MethodREGISTER},
		{"INVITE", []byte("INVITE sip:alice@example.com SIP/2.0\r\n"), MethodINVITE},
		{"OPTIONS", []byte("OPTIONS sip:example.com SIP/2.0\r\n"), MethodOPTIONS},
		{"ACK", []byte("ACK sip:alice@example.com SIP/2.0\r\n"), MethodACK},
		{"BYE", []byte("BYE sip:alice@example.com SIP/2.0\r\n"), MethodBYE},
		{"CANCEL", []byte("CANCEL sip:alice@example.com SIP/2.0\r\n"), MethodCANCEL},
		{"INFO", []byte("INFO sip:alice@example.com SIP/2.0\r\n"), MethodINFO},
		{"NOTIFY", []byte("NOTIFY sip:alice@example.com SIP/2.0\r\n"), MethodNOTIFY},
		{"SUBSCRIBE", []byte("SUBSCRIBE sip:alice@example.com SIP/2.0\r\n"), MethodSUBSCRIBE},
		{"MESSAGE", []byte("MESSAGE sip:alice@example.com SIP/2.0\r\n"), MethodMESSAGE},
		{"UPDATE", []byte("UPDATE sip:alice@example.com SIP/2.0\r\n"), MethodUPDATE},
		{"PRACK", []byte("PRACK sip:alice@example.com SIP/2.0\r\n"), MethodPRACK},
		{"REFER", []byte("REFER sip:alice@example.com SIP/2.0\r\n"), MethodREFER},
		{"PUBLISH", []byte("PUBLISH sip:alice@example.com SIP/2.0\r\n"), MethodPUBLISH},
		{"Response (no method)", []byte("SIP/2.0 200 OK\r\n"), ""},
		{"Non-SIP", []byte("GET / HTTP/1.1\r\n"), ""},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			method := ExtractSIPMethod(tt.data)
			if method != tt.expected {
				t.Errorf("ExtractSIPMethod: got %q, want %q", method, tt.expected)
			}
		})
	}
}

// Note: TestExtractTargetExtension is already defined in enumeration_test.go.
// These additional tests cover edge cases for rate limiter integration.
func TestExtractTargetExtensionEdgeCases(t *testing.T) {
	tests := []struct {
		name     string
		data     []byte
		expected string
	}{
		// Legitimate extension extractions (additional cases)
		{
			name:     "Simple 4-digit extension",
			data:     []byte("REGISTER sip:1001@example.com SIP/2.0\r\n"),
			expected: "1001",
		},
		{
			name:     "3-digit extension",
			data:     []byte("INVITE sip:200@pbx.local SIP/2.0\r\n"),
			expected: "200",
		},
		{
			name:     "Alphanumeric short name",
			data:     []byte("INVITE sip:john@example.com SIP/2.0\r\n"),
			expected: "john",
		},
		// Should NOT extract these as extensions (domain-like or too long)
		{
			name:     "Full domain should not be extracted",
			data:     []byte("REGISTER sip:example.com SIP/2.0\r\n"),
			expected: "", // Contains a dot, filtered out
		},
		{
			name:     "Too long identifier",
			data:     []byte("INVITE sip:verylongusername@example.com SIP/2.0\r\n"),
			expected: "", // >10 chars
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ext := ExtractTargetExtension(tt.data)
			if ext != tt.expected {
				t.Errorf("ExtractTargetExtension: got %q, want %q", ext, tt.expected)
			}
		})
	}
}

func TestParseUserAgent(t *testing.T) {
	tests := []struct {
		name     string
		data     []byte
		expected string
	}{
		{
			name:     "Yealink phone",
			data:     []byte("REGISTER sip:example.com SIP/2.0\r\n" + "User-Agent: Yealink SIP-T46S 66.86.0.15\r\n"),
			expected: "Yealink SIP-T46S 66.86.0.15",
		},
		{
			name:     "Linphone",
			data:     []byte("INVITE sip:alice@example.com SIP/2.0\r\n" + "User-Agent: Linphone/4.5.0\r\n"),
			expected: "Linphone/4.5.0",
		},
		{
			name:     "No User-Agent",
			data:     []byte("OPTIONS sip:example.com SIP/2.0\r\n" + "Via: SIP/2.0/UDP 192.168.1.1\r\n"),
			expected: "",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ua := ParseUserAgent(tt.data)
			if ua != tt.expected {
				t.Errorf("ParseUserAgent: got %q, want %q", ua, tt.expected)
			}
		})
	}
}

func TestParseFromHeader(t *testing.T) {
	// Header values use the RFC 3261 name-addr form: [display-name] <sip:user@host>.
	tests := []struct {
		name           string
		data           []byte
		expectedUser   string
		expectedDomain string
	}{
		{
			name:           "Standard From header",
			data:           []byte("INVITE sip:bob@example.com SIP/2.0\r\n" + "From: \"Alice\" <sip:alice@sip.example.com>;tag=1234\r\n"),
			expectedUser:   "alice",
			expectedDomain: "sip.example.com",
		},
		{
			name:           "From header without display name",
			data:           []byte("REGISTER sip:example.com SIP/2.0\r\n" + "From: <sip:1001@pbx.local>;tag=abc\r\n"),
			expectedUser:   "1001",
			expectedDomain: "pbx.local",
		},
		{
			name:           "No From header",
			data:           []byte("OPTIONS sip:example.com SIP/2.0\r\n" + "To: <sip:bob@example.com>\r\n"),
			expectedUser:   "",
			expectedDomain: "",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			user, domain := ParseFromHeader(tt.data)
			if user != tt.expectedUser {
				t.Errorf("ParseFromHeader user: got %q, want %q", user, tt.expectedUser)
			}
			if domain != tt.expectedDomain {
				t.Errorf("ParseFromHeader domain: got %q, want %q", domain, tt.expectedDomain)
			}
		})
	}
}

func TestParseToHeader(t *testing.T) {
	// Header values use the RFC 3261 name-addr form: [display-name] <sip:user@host>.
	tests := []struct {
		name           string
		data           []byte
		expectedUser   string
		expectedDomain string
	}{
		{
			name:           "Standard To header",
			data:           []byte("INVITE sip:bob@example.com SIP/2.0\r\n" + "To: \"Bob\" <sip:bob@sip.example.com>\r\n"),
			expectedUser:   "bob",
			expectedDomain: "sip.example.com",
		},
		{
			name:           "To header without display name",
			data:           []byte("REGISTER sip:example.com SIP/2.0\r\n" + "To: <sip:5002@pbx.local>\r\n"),
			expectedUser:   "5002",
			expectedDomain: "pbx.local",
		},
		{
			name:           "No To header",
			data:           []byte("OPTIONS sip:example.com SIP/2.0\r\n" + "From: <sip:alice@example.com>;tag=1234\r\n"),
			expectedUser:   "",
			expectedDomain: "",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			user, domain := ParseToHeader(tt.data)
			if user != tt.expectedUser {
				t.Errorf("ParseToHeader user: got %q, want %q", user, tt.expectedUser)
			}
			if domain != tt.expectedDomain {
				t.Errorf("ParseToHeader domain: got %q, want %q", domain, tt.expectedDomain)
			}
		})
	}
}

func TestParseCallID(t *testing.T) {
	tests := []struct {
		name     string
		data     []byte
		expected string
	}{
		{
			name:     "Standard Call-ID header",
			data:     []byte("INVITE sip:bob@example.com SIP/2.0\r\n" + "Call-ID: 12345-67890@192.168.1.100\r\n"),
			expected: "12345-67890@192.168.1.100",
		},
		{
			name:     "Short form Call-ID (i:)",
			data:     []byte("REGISTER sip:example.com SIP/2.0\r\n" + "i: compact-callid@host\r\n"),
			expected: "compact-callid@host",
		},
		{
			name:     "No Call-ID",
			data:     []byte("OPTIONS sip:example.com SIP/2.0\r\n" + "From: <sip:alice@example.com>;tag=1234\r\n"),
			expected: "",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			callID := ParseCallID(tt.data)
			if callID != tt.expected {
				t.Errorf("ParseCallID: got %q, want %q", callID, tt.expected)
			}
		})
	}
}

// =============================================================================
// Real-world Legitimate Traffic Scenarios
// =============================================================================

func TestLegitimatePhoneRegistration(t *testing.T) {
	// Simulate a typical phone registration pattern:
	// 1. Initial REGISTER (usually 401/407 challenge)
	// 2. Second REGISTER with auth
	// 3. Keep-alive OPTIONS
	// 4. Re-REGISTER before expiry
	logger := zap.NewNop()

	// Reset global rate limiter for clean test
	resetGlobalRateLimiterForTest(t)

	rl := GetRateLimiter(logger)
	ip := "192.168.1.200"

	// Simulate pattern over time
	allowedCount := 0

	// Initial REGISTER
	if allowed, _ := rl.Allow(ip, MethodREGISTER); allowed {
		allowedCount++
	}

	// Auth REGISTER
	if allowed, _ := rl.Allow(ip, MethodREGISTER); allowed {
		allowedCount++
	}

	// Keep-alive OPTIONS (phones do this frequently)
	for i := 0; i < 5; i++ {
		if allowed, _ := rl.Allow(ip, MethodOPTIONS); allowed {
			allowedCount++
		}
	}

	// Another REGISTER
	if allowed, _ := rl.Allow(ip, MethodREGISTER); allowed {
		allowedCount++
	}

	// At minimum, the phone should complete registration (2 REGISTER + OPTIONS)
	if allowedCount < 5 {
		t.Errorf("Legitimate phone registration pattern blocked too early, only %d requests allowed", allowedCount)
	}
}

func TestLegitimateCallFlow(t *testing.T) {
	// Simulate a typical call flow:
	// INVITE -> 100 Trying -> 180 Ringing -> 200 OK -> ACK -> (call) -> BYE -> 200 OK
	logger := zap.NewNop()

	// Reset global rate limiter
	resetGlobalRateLimiterForTest(t)

	rl := GetRateLimiter(logger)
	ip := "192.168.1.201"

	// INVITE
	allowed, reason := rl.Allow(ip, MethodINVITE)
	if !allowed {
		t.Errorf("INVITE should be allowed for call: %s", reason)
	}

	// ACK (not rate limited by default)
	allowed, reason = rl.Allow(ip, MethodACK)
	if !allowed {
		t.Errorf("ACK should be allowed for call: %s", reason)
	}

	// BYE
	allowed, reason = rl.Allow(ip, MethodBYE)
	if !allowed {
		t.Errorf("BYE should be allowed for call: %s", reason)
	}
}

func TestLegitimateSubscriptionFlow(t *testing.T) {
	// Simulate presence/BLF subscription
	logger := zap.NewNop()

	// Reset global rate limiter
	resetGlobalRateLimiterForTest(t)

	rl := GetRateLimiter(logger)
	ip := "192.168.1.202"

	// SUBSCRIBE for presence (phones do multiple for BLF)
	allowedCount := 0
	for i := 0; i < 10; i++ {
		if allowed, _ := rl.Allow(ip, MethodSUBSCRIBE); allowed {
			allowedCount++
		}
	}

	// Default allows 5 burst + some refill
	if allowedCount < 5 {
		t.Errorf("Legitimate SUBSCRIBE pattern blocked too early, only %d allowed", allowedCount)
	}
}

func TestLegitimateMessaging(t *testing.T) {
	// Simulate SMS/messaging traffic
	logger := zap.NewNop()

	// Reset global rate limiter
	resetGlobalRateLimiterForTest(t)

	rl := GetRateLimiter(logger)
	ip := "192.168.1.203"

	// MESSAGE (higher limit - 100/min default)
	allowedCount := 0
	for i := 0; i < 25; i++ {
		if allowed, _ := rl.Allow(ip, MethodMESSAGE); allowed {
			allowedCount++
		}
	}

	// Should allow most messages (20 burst default)
	if allowedCount < 20 {
		t.Errorf("Legitimate MESSAGE pattern blocked too early, only %d allowed", allowedCount)
	}
}