package sipguardian

import (
	"strconv"
	"testing"
	"time"

	"go.uber.org/zap"
)

// newTestDetector builds an EnumerationDetector for tests, using a no-op zap
// logger so test runs produce no log output.
func newTestDetector(config EnumerationConfig) *EnumerationDetector {
	logger := zap.NewNop()
	return NewEnumerationDetector(logger, config)
}

// TestDefaultConfig pins the field values returned by DefaultEnumerationConfig.
func TestDefaultConfig(t *testing.T) {
	config := DefaultEnumerationConfig()

	if config.MaxExtensions != 20 {
		t.Errorf("Expected MaxExtensions=20, got %d", config.MaxExtensions)
	}
	if config.ExtensionWindow != 5*time.Minute {
		t.Errorf("Expected ExtensionWindow=5m, got %v", config.ExtensionWindow)
	}
	if config.SequentialThreshold != 5 {
		t.Errorf("Expected SequentialThreshold=5, got %d", config.SequentialThreshold)
	}
	if config.RapidFireCount != 10 {
		t.Errorf("Expected RapidFireCount=10, got %d", config.RapidFireCount)
	}
	if config.RapidFireWindow != 30*time.Second {
		t.Errorf("Expected RapidFireWindow=30s, got %v", config.RapidFireWindow)
	}
}

// TestMaxExtensionsDetection checks that the count-based rule fires exactly
// when the number of unique extensions reaches MaxExtensions.
func TestMaxExtensionsDetection(t *testing.T) {
	config := EnumerationConfig{
		MaxExtensions:       5,
		ExtensionWindow:     5 * time.Minute,
		SequentialThreshold: 100, // Disable sequential detection
		RapidFireCount:      100, // Disable rapid-fire detection
		RapidFireWindow:     30 * time.Second,
	}
	detector := newTestDetector(config)
	ip := "192.168.1.100"

	// Record 4 different extensions - should not trigger
	for i := 0; i < 4; i++ {
		ext := strconv.Itoa(1000 + i*100) // Non-sequential: 1000, 1100, 1200, 1300
		result := detector.RecordAttempt(ip, ext)
		if result.Detected {
			t.Errorf("Should not detect on extension %d (count=%d)", i+1, result.UniqueCount)
		}
	}

	// 5th unique extension should trigger
	result := detector.RecordAttempt(ip, "2000")
	if !result.Detected {
		t.Error("Should detect when max_extensions reached")
	}
	if result.Reason != "extension_count_exceeded" {
		t.Errorf("Expected reason 'extension_count_exceeded', got '%s'", result.Reason)
	}
	if result.UniqueCount != 5 {
		t.Errorf("Expected unique_count=5, got %d", result.UniqueCount)
	}
}

// TestSequentialPatternDetection checks that SequentialThreshold consecutive
// numeric extensions trigger the sequential rule and report the run bounds.
func TestSequentialPatternDetection(t *testing.T) {
	config := EnumerationConfig{
		MaxExtensions:       100, // High to avoid triggering count-based
		ExtensionWindow:     5 * time.Minute,
		SequentialThreshold: 5,
		RapidFireCount:      100, // Disable rapid-fire
		RapidFireWindow:     30 * time.Second,
	}
	detector := newTestDetector(config)
	ip := "192.168.1.101"

	// Record sequential extensions: 100, 101, 102, 103
	for i := 100; i <= 103; i++ {
		result := detector.RecordAttempt(ip, strconv.Itoa(i))
		if result.Detected {
			t.Errorf("Should not detect on extension %d", i)
		}
	}

	// 5th sequential should trigger
	result := detector.RecordAttempt(ip, "104")
	if !result.Detected {
		t.Error("Should detect sequential pattern at 5 consecutive")
	}
	if result.Reason != "sequential_enumeration" {
		t.Errorf("Expected reason 'sequential_enumeration', got '%s'", result.Reason)
	}
	if result.SeqStart != 100 || result.SeqEnd != 104 {
		t.Errorf("Expected SeqStart=100, SeqEnd=104, got %d-%d", result.SeqStart, result.SeqEnd)
	}
}

// TestSequentialPatternGaps checks that numbers with gaps (step 2) are not
// treated as a sequential run.
func TestSequentialPatternGaps(t *testing.T) {
	config := EnumerationConfig{
		MaxExtensions:       100,
		ExtensionWindow:     5 * time.Minute,
		SequentialThreshold: 5,
		RapidFireCount:      100,
		RapidFireWindow:     30 * time.Second,
	}
	detector := newTestDetector(config)
	ip := "192.168.1.102"

	// Non-sequential extensions with gaps
	extensions := []string{"100", "102", "104", "106", "108"}
	for _, ext := range extensions {
		result := detector.RecordAttempt(ip, ext)
		if result.Detected && result.Reason == "sequential_enumeration" {
			t.Errorf("Should not detect sequential pattern for non-consecutive: %s", ext)
		}
	}
}

// TestRapidFireDetection checks that RapidFireCount unique extensions inside
// RapidFireWindow trigger the rapid-fire rule.
func TestRapidFireDetection(t *testing.T) {
	config := EnumerationConfig{
		MaxExtensions:       100, // High to avoid triggering count-based
		ExtensionWindow:     5 * time.Minute,
		SequentialThreshold: 100, // Disable sequential
		RapidFireCount:      5,
		RapidFireWindow:     1 * time.Second, // Short window for testing
	}
	detector := newTestDetector(config)
	ip := "192.168.1.103"

	// Record 4 different non-sequential extensions rapidly (within the
	// window); one below RapidFireCount, so rapid-fire must not fire yet.
	for i := 0; i < 4; i++ {
		ext := strconv.Itoa(1000 + i*100) // Non-sequential
		result := detector.RecordAttempt(ip, ext)
		if result.Detected && result.Reason == "rapid_fire_enumeration" {
			t.Errorf("Should not detect rapid-fire on attempt %d", i+1)
		}
	}

	// 5th should trigger rapid-fire
	result := detector.RecordAttempt(ip, "5000")
	if !result.Detected {
		t.Error("Should detect rapid-fire pattern")
	}
	if result.Reason != "rapid_fire_enumeration" {
		t.Errorf("Expected reason 'rapid_fire_enumeration', got '%s'", result.Reason)
	}
}

// TestExemptExtensions checks that extensions listed in ExemptExtensions are
// ignored by detection while other extensions still count.
func TestExemptExtensions(t *testing.T) {
	config := EnumerationConfig{
		MaxExtensions:       3,
		ExtensionWindow:     5 * time.Minute,
		SequentialThreshold: 3,
		RapidFireCount:      100,
		RapidFireWindow:     30 * time.Second,
		ExemptExtensions:    []string{"100", "200", "emergency"},
	}
	detector := newTestDetector(config)
	ip := "192.168.1.104"

	// Exempt extensions should not count
	exemptExts := []string{"100", "200", "emergency"}
	for _, ext := range exemptExts {
		result := detector.RecordAttempt(ip, ext)
		if result.Detected {
			t.Errorf("Exempt extension '%s' should not trigger detection", ext)
		}
	}

	// Non-exempt extensions should still count
	result := detector.RecordAttempt(ip, "1001")
	if result.Detected {
		t.Error("First non-exempt should not trigger")
	}

	result = detector.RecordAttempt(ip, "1002")
	if result.Detected {
		t.Error("Second non-exempt should not trigger")
	}

	// 1001-1003 also forms a run of SequentialThreshold (3), so either rule
	// may be reported; only Detected is asserted here.
	result = detector.RecordAttempt(ip, "1003")
	if !result.Detected {
		t.Error("Third non-exempt should trigger (max_extensions=3)")
	}
}

// TestDuplicateExtensions checks that repeating the same extension counts as
// a single unique extension.
func TestDuplicateExtensions(t *testing.T) {
	config := EnumerationConfig{
		MaxExtensions:       3,
		ExtensionWindow:     5 * time.Minute,
		SequentialThreshold: 100,
		RapidFireCount:      100,
		RapidFireWindow:     30 * time.Second,
	}
	detector := newTestDetector(config)
	ip := "192.168.1.105"

	// Record same extension multiple times - should only count as 1
	for i := 0; i < 10; i++ {
		result := detector.RecordAttempt(ip, "1000")
		if result.Detected {
			t.Error("Duplicate extensions should not trigger detection")
		}
		if result.UniqueCount != 1 {
			t.Errorf("Expected unique_count=1 for duplicates, got %d", result.UniqueCount)
		}
	}
}

// TestMultipleIPsIsolation checks that per-IP state is tracked independently.
func TestMultipleIPsIsolation(t *testing.T) {
	config := EnumerationConfig{
		MaxExtensions:       3,
		ExtensionWindow:     5 * time.Minute,
		SequentialThreshold: 100,
		RapidFireCount:      100,
		RapidFireWindow:     30 * time.Second,
	}
	detector := newTestDetector(config)
	ip1 := "192.168.1.106"
	ip2 := "192.168.1.107"

	// Record extensions for IP1
	for i := 0; i < 2; i++ {
		detector.RecordAttempt(ip1, strconv.Itoa(1000+i))
	}

	// Record extensions for IP2 - should start fresh
	result := detector.RecordAttempt(ip2, "2000")
	if result.UniqueCount != 1 {
		t.Errorf("IP2 should have independent count, expected 1, got %d", result.UniqueCount)
	}

	// IP1's 3rd should trigger
	result = detector.RecordAttempt(ip1, "1002")
	if !result.Detected {
		t.Error("IP1 should trigger on 3rd unique extension")
	}

	// IP2 should still be fine
	result = detector.RecordAttempt(ip2, "2001")
	if result.Detected {
		t.Error("IP2 should not trigger on 2nd unique extension")
	}
}

// TestGetStats checks the tracked_ips and total_extensions counters reported
// by GetStats.
func TestGetStats(t *testing.T) {
	config := DefaultEnumerationConfig()
	detector := newTestDetector(config)

	// Initial stats. Comma-ok assertions report a wrong type as a test
	// failure instead of panicking and aborting the whole test binary.
	stats := detector.GetStats()
	if n, ok := stats["tracked_ips"].(int); !ok || n != 0 {
		t.Errorf("Expected tracked_ips=0 initially, got %v", stats["tracked_ips"])
	}

	// Record some attempts
	detector.RecordAttempt("192.168.1.1", "1000")
	detector.RecordAttempt("192.168.1.2", "2000")

	stats = detector.GetStats()
	if n, ok := stats["tracked_ips"].(int); !ok || n != 2 {
		t.Errorf("Expected tracked_ips=2, got %v", stats["tracked_ips"])
	}
	if n, ok := stats["total_extensions"].(int); !ok || n != 2 {
		t.Errorf("Expected total_extensions=2, got %v", stats["total_extensions"])
	}
}

// TestGetIPAttempts checks that GetIPAttempts returns nil for an unknown IP
// and the recorded extensions for a tracked one.
func TestGetIPAttempts(t *testing.T) {
	config := DefaultEnumerationConfig()
	detector := newTestDetector(config)
	ip := "192.168.1.108"

	// No attempts yet
	result := detector.GetIPAttempts(ip)
	if result != nil {
		t.Error("Expected nil for non-tracked IP")
	}

	// Record some attempts
	detector.RecordAttempt(ip, "1000")
	detector.RecordAttempt(ip, "1001")
	detector.RecordAttempt(ip, "1002")

	result = detector.GetIPAttempts(ip)
	if result == nil {
		t.Fatal("Expected result for tracked IP")
	}
	if result.UniqueCount != 3 {
		t.Errorf("Expected unique_count=3, got %d", result.UniqueCount)
	}
	if len(result.Extensions) != 3 {
		t.Errorf("Expected 3 extensions, got %d", len(result.Extensions))
	}
}

// TestResetIP checks that ResetIP discards an IP's recorded extensions.
func TestResetIP(t *testing.T) {
	config := EnumerationConfig{
		MaxExtensions:       3,
		ExtensionWindow:     5 * time.Minute,
		SequentialThreshold: 100,
		RapidFireCount:      100,
		RapidFireWindow:     30 * time.Second,
	}
	detector := newTestDetector(config)
	ip := "192.168.1.109"

	// Record 2 extensions
	detector.RecordAttempt(ip, "1000")
	detector.RecordAttempt(ip, "1001")

	// Reset the IP
	detector.ResetIP(ip)

	// Should start fresh - no detection yet
	result := detector.RecordAttempt(ip, "2000")
	if result.Detected {
		t.Error("After reset, a single extension should not trigger detection")
	}
	if result.UniqueCount != 1 {
		t.Errorf("After reset, expected unique_count=1, got %d", result.UniqueCount)
	}
}

// TestNonNumericExtensions checks that named (non-numeric) extensions never
// trigger the sequential rule.
func TestNonNumericExtensions(t *testing.T) {
	config := EnumerationConfig{
		MaxExtensions:       100,
		ExtensionWindow:     5 * time.Minute,
		SequentialThreshold: 3,
		RapidFireCount:      100,
		RapidFireWindow:     30 * time.Second,
	}
	detector := newTestDetector(config)
	ip := "192.168.1.110"

	// Non-numeric extensions should not trigger sequential detection
	nonNumeric := []string{"sales", "support", "main", "fax", "reception"}
	for _, ext := range nonNumeric {
		result := detector.RecordAttempt(ip, ext)
		if result.Detected && result.Reason == "sequential_enumeration" {
			t.Errorf("Non-numeric '%s' should not trigger sequential detection", ext)
		}
	}
}

// TestMixedNumericNonNumeric checks that non-numeric extensions interleaved
// with a numeric run do not hide the sequential pattern.
func TestMixedNumericNonNumeric(t *testing.T) {
	config := EnumerationConfig{
		MaxExtensions:       100,
		ExtensionWindow:     5 * time.Minute,
		SequentialThreshold: 5,
		RapidFireCount:      100,
		RapidFireWindow:     30 * time.Second,
	}
	detector := newTestDetector(config)
	ip := "192.168.1.111"

	// Mix of numeric sequential with non-numeric interruptions
	// Still should detect sequence in numeric ones
	extensions := []string{"100", "main", "101", "support", "102", "sales", "103", "104"}
	var detectedSeq bool
	for _, ext := range extensions {
		result := detector.RecordAttempt(ip, ext)
		if result.Detected && result.Reason == "sequential_enumeration" {
			detectedSeq = true
		}
	}

	if !detectedSeq {
		t.Error("Should detect 5 sequential numeric extensions even with non-numeric mixed in")
	}
}

// TestExtractTargetExtension is a table test for ExtractTargetExtension over
// raw SIP request bytes.
func TestExtractTargetExtension(t *testing.T) {
	testCases := []struct {
		name     string
		data     string
		expected string
	}{
		{
			name:     "REGISTER with extension",
			data:     "REGISTER sip:1001@example.com SIP/2.0\r\nVia: SIP/2.0/UDP 192.168.1.1\r\n",
			expected: "1001",
		},
		{
			name:     "INVITE with extension",
			data:     "INVITE sip:2000@pbx.local SIP/2.0\r\nFrom: \r\n",
			expected: "2000",
		},
		{
			name:     "OPTIONS with extension",
			data:     "OPTIONS sip:100@domain.com SIP/2.0\r\n",
			expected: "100",
		},
		{
			name:     "Extension too long (should skip)",
			data:     "REGISTER sip:verylongextensionname@example.com SIP/2.0\r\n",
			expected: "",
		},
		{
			name:     "Domain-like user (should skip)",
			data:     "REGISTER sip:example.com@example.com SIP/2.0\r\n",
			expected: "",
		},
		{
			name:     "BYE method (not tracked)",
			data:     "BYE sip:1001@example.com SIP/2.0\r\n",
			expected: "",
		},
		{
			// The expected "500" must come from the To header URI; the
			// <sip:...> value had been lost, leaving nothing to extract.
			name:     "Fallback to To header",
			data:     "ACK sip:anything@example.com SIP/2.0\r\nTo: <sip:500@example.com>\r\n",
			expected: "500",
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			result := ExtractTargetExtension([]byte(tc.data))
			if result != tc.expected {
				t.Errorf("Expected '%s', got '%s'", tc.expected, result)
			}
		})
	}
}