package sipguardian

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"net"
	"regexp"
	"strings"
	"unicode"

	"github.com/caddyserver/caddy/v2"
	"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
	"github.com/mholt/caddy-l4/layer4"
	"go.uber.org/zap"
)

func init() {
	caddy.RegisterModule(SIPMatcher{})
	caddy.RegisterModule(SIPHandler{})
}

// SIPMatcher matches SIP traffic by inspecting the first bytes of a
// connection: either a request line for one of Methods ("INVITE sip:...")
// or a SIP status line ("SIP/2.0 ...").
type SIPMatcher struct {
	// Methods lists the SIP request methods to match (REGISTER, INVITE,
	// OPTIONS, ...). Matching is case-insensitive. When empty, Provision
	// fills in a default set.
	Methods []string `json:"methods,omitempty"`

	methodRegex *regexp.Regexp
}

// CaddyModule returns the Caddy module information.
func (SIPMatcher) CaddyModule() caddy.ModuleInfo {
	return caddy.ModuleInfo{
		ID:  "layer4.matchers.sip",
		New: func() caddy.Module { return new(SIPMatcher) },
	}
}

// Provision fills in the default method set when none is configured and
// compiles the request-line pattern used by Match.
//
// Method names are quoted, so they are always matched literally, never as
// regex syntax. An empty method name is rejected because it would otherwise
// match any line beginning with " sip:".
func (m *SIPMatcher) Provision(ctx caddy.Context) error {
	if len(m.Methods) == 0 {
		m.Methods = []string{
			"REGISTER", "INVITE", "OPTIONS", "ACK", "BYE", "CANCEL",
			"INFO", "NOTIFY", "SUBSCRIBE", "MESSAGE",
			"PRACK", "UPDATE", "REFER", "PUBLISH",
		}
	}

	alternatives := make([]string, 0, len(m.Methods))
	for _, method := range m.Methods {
		if method == "" {
			return errors.New("sip matcher: method names must not be empty")
		}
		alternatives = append(alternatives, regexp.QuoteMeta(method))
	}

	re, err := regexp.Compile("(?i)^(" + strings.Join(alternatives, "|") + ") sip:")
	if err != nil {
		return fmt.Errorf("sip matcher: compiling method pattern: %w", err)
	}
	m.methodRegex = re
	return nil
}

// Match reports whether the connection appears to carry SIP traffic.
//
// It reads up to 64 bytes. If the stream ends early but at least 8 bytes
// arrived (the length of the shortest matchable prefixes such as "ACK sip:",
// "BYE sip:" or "SIP/2.0 "), the shorter buffer is inspected. Any other read
// error is returned unchanged so caddy-l4 can react to it. This includes the
// prefetch-exhausted error that makes caddy-l4 fetch more data and retry.
// On a match, the inspected bytes are stored in the "sip_peek" variable.
func (m *SIPMatcher) Match(cx *layer4.Connection) (bool, error) {
	buf := make([]byte, 64)
	n, err := io.ReadFull(cx, buf)
	switch {
	case err == nil:
		// Full 64-byte peek.
	case errors.Is(err, io.ErrUnexpectedEOF) && n >= 8:
		buf = buf[:n]
	default:
		return false, err
	}

	// Status line: RFC 3261 §7.1 defines the SIP-Version string as
	// case-insensitive (senders must use upper case, receivers must accept both).
	const statusPrefix = "SIP/2.0"
	isResponse := len(buf) >= len(statusPrefix) &&
		bytes.EqualFold(buf[:len(statusPrefix)], []byte(statusPrefix))

	if !m.methodRegex.Match(buf) && !isResponse {
		return false, nil
	}
	cx.SetVar("sip_peek", buf)
	return true, nil
}

// SIPHandler is a layer4 handler that enforces SIP Guardian rules before
// passing the connection to the next handler. The rules cover the ban list,
// whitelist, GeoIP, message validation, per-method rate limits, extension
// enumeration and known scanner signatures.
type SIPHandler struct {
	// GuardianRef names a shared guardian instance.
	// NOTE(review): Provision does not read this field; every handler uses
	// the "default" guardian. Confirm whether it should select the registry name.
	GuardianRef string `json:"guardian,omitempty"`

	// Upstream is the address to proxy to.
	// NOTE(review): not referenced anywhere in this file.
	Upstream string `json:"upstream,omitempty"`

	// Embedded guardian config parsed from the Caddyfile. It is handed to the
	// shared guardian during Provision.
	SIPGuardian

	logger   *zap.Logger
	guardian *SIPGuardian
}

// CaddyModule returns the Caddy module information.
func (SIPHandler) CaddyModule() caddy.ModuleInfo {
	return caddy.ModuleInfo{
		ID:  "layer4.handlers.sip_guardian",
		New: func() caddy.Module { return new(SIPHandler) },
	}
}

// Provision sets up the logger and obtains the shared "default" guardian via
// GetOrCreateGuardianWithConfig, passing this handler's parsed configuration.
func (h *SIPHandler) Provision(ctx caddy.Context) error {
	h.logger = ctx.Logger()

	guardian, err := GetOrCreateGuardianWithConfig(ctx, "default", &h.SIPGuardian)
	if err != nil {
		return err
	}
	h.guardian = guardian
	return nil
}

// Handle enforces SIP Guardian policy on a connection.
//
// The checks run in this order:
//   1. The ban list.
//   2. The whitelist, which bypasses every later check.
//   3. GeoIP blocking.
//   4. Inspection of the first chunk read from the connection (see
//      inspectMessage).
//
// If a check rejects the connection, it is closed and the error from Close is
// returned. Otherwise the connection is handed to next.
func (h *SIPHandler) Handle(cx *layer4.Connection, next layer4.Handler) error {
	remoteAddr := cx.RemoteAddr().String()
	host, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		// Address without a port; use it verbatim as the key.
		host = remoteAddr
	}

	if h.guardian.IsBanned(host) {
		h.logger.Debug("Blocked banned IP", zap.String("ip", host))
		if enableMetrics {
			RecordConnection("blocked")
		}
		return cx.Close()
	}

	if h.guardian.IsWhitelisted(host) {
		if enableMetrics {
			RecordConnection("allowed")
		}
		return next.Handle(cx)
	}

	if blocked, country := h.guardian.IsCountryBlocked(host); blocked {
		h.logger.Info("Blocked connection from blocked country",
			zap.String("ip", host),
			zap.String("country", country),
		)
		if enableMetrics {
			RecordConnection("geo_blocked")
		}
		return cx.Close()
	}

	// A single Read may return only part of a SIP message. The content checks
	// see only what this Read returned.
	// NOTE(review): cx.Read consumes the bytes it returns. Unless
	// layer4.Connection replays them to later readers, handlers after this
	// one (e.g. a proxy) will not receive this chunk. Verify against
	// caddy-l4's Connection.Read and, if needed, replay buf to next.
	buf := make([]byte, 4096)
	n, err := cx.Read(buf)
	switch {
	case n > 0:
		if h.inspectMessage(host, buf[:n]) {
			return cx.Close()
		}
	case err != nil:
		h.logger.Debug("Failed to read SIP data for inspection",
			zap.String("ip", host),
			zap.Error(err),
		)
	}

	if enableMetrics {
		RecordConnection("allowed")
	}
	return next.Handle(cx)
}

// inspectMessage runs the content checks on data read from the connection
// and reports whether the connection must be closed. The checks run in this
// order:
//   1. Structural validation.
//   2. Per-method rate limiting.
//   3. Extension-enumeration detection.
//   4. Scanner signatures.
//
// Evaluation stops at the first check that rejects.
func (h *SIPHandler) inspectMessage(host string, buf []byte) bool {
	h.logger.Debug("Read SIP data for inspection",
		zap.String("ip", host),
		zap.Int("bytes", len(buf)),
	)
	if enableMetrics {
		RecordMessageSize(len(buf))
	}

	return h.rejectInvalidMessage(host, buf) ||
		h.rejectRateLimited(host, buf) ||
		h.rejectEnumeration(host, buf) ||
		h.rejectSuspiciousPattern(host, buf)
}

// rejectInvalidMessage validates buf when the validator is enabled. It
// reports true, after recording a failure for host, only when the result
// says the sender should be banned.
func (h *SIPHandler) rejectInvalidMessage(host string, buf []byte) bool {
	validator := GetValidator(h.logger)
	if !validator.IsEnabled() {
		return false
	}

	result := validator.Validate(buf)
	if enableMetrics {
		for _, v := range result.Violations {
			RecordValidationViolation(v.Rule)
		}
		switch {
		case result.Valid:
			RecordValidationResult("valid")
		case result.ShouldBan:
			RecordValidationResult("ban")
		default:
			RecordValidationResult("invalid")
		}
	}
	if result.Valid {
		return false
	}

	h.logger.Warn("SIP validation failed",
		zap.String("ip", host),
		zap.Int("violation_count", len(result.Violations)),
		zap.Bool("should_ban", result.ShouldBan),
	)
	for _, v := range result.Violations {
		h.logger.Debug("Validation violation",
			zap.String("ip", host),
			zap.String("rule", v.Rule),
			zap.String("severity", string(v.Severity)),
			zap.String("message", v.Message),
		)
	}

	if !result.ShouldBan {
		// Non-banning violations are only logged and metered. No failure is
		// recorded and the message continues through the remaining checks.
		// NOTE(review): confirm whether strict/paranoid validation modes are
		// meant to reject here, or whether Validate already folds the mode
		// into ShouldBan.
		return false
	}

	if enableMetrics {
		RecordConnection("validation_blocked")
	}
	h.guardian.RecordFailure(host, result.BanReason)
	return true
}

// rejectRateLimited applies the per-method rate limiter when a SIP method
// can be extracted from buf. A rejection is recorded as a guardian failure.
func (h *SIPHandler) rejectRateLimited(host string, buf []byte) bool {
	method := ExtractSIPMethod(buf)
	if method == "" {
		return false
	}

	rl := GetRateLimiter(h.logger)
	allowed, reason := rl.Allow(host, method)
	if allowed {
		return false
	}

	h.logger.Warn("Rate limit exceeded",
		zap.String("ip", host),
		zap.String("method", string(method)),
	)
	if enableMetrics {
		RecordConnection("rate_limited")
	}
	// Counts toward the ban threshold.
	h.guardian.RecordFailure(host, reason)
	return true
}

// rejectEnumeration feeds the target extension in buf, if any, to the
// enumeration detector. On detection it persists the event (when storage is
// configured), emits a webhook (when enabled) and records a failure.
func (h *SIPHandler) rejectEnumeration(host string, buf []byte) bool {
	extension := ExtractTargetExtension(buf)
	if extension == "" {
		return false
	}

	detector := GetEnumerationDetector(h.logger)
	result := detector.RecordAttempt(host, extension)
	if !result.Detected {
		if enableMetrics {
			stats := detector.GetStats()
			if trackedIPs, ok := stats["tracked_ips"].(int); ok {
				UpdateEnumerationTrackedIPs(trackedIPs)
			}
		}
		return false
	}

	h.logger.Warn("Enumeration attack detected",
		zap.String("ip", host),
		zap.String("reason", result.Reason),
		zap.Int("unique_extensions", result.UniqueCount),
		zap.Strings("extensions", result.Extensions),
	)
	if enableMetrics {
		RecordEnumerationDetection(result.Reason)
		RecordEnumerationExtensions(result.UniqueCount)
		RecordConnection("enumeration_blocked")
	}
	if h.guardian.storage != nil {
		go h.guardian.storage.RecordEnumerationAttempt(host, result.Reason, result.UniqueCount, result.Extensions)
	}
	if enableWebhooks {
		go EmitEnumerationEvent(h.logger, host, result)
	}
	h.guardian.RecordFailure(host, "enumeration_"+result.Reason)
	return true
}

// rejectSuspiciousPattern looks for known scanner signatures in buf. A hit
// is logged, metered, persisted (when storage is configured) and recorded as
// a failure. It rejects only when that failure causes a ban. Otherwise the
// message is let through.
func (h *SIPHandler) rejectSuspiciousPattern(host string, buf []byte) bool {
	pattern := detectSuspiciousPattern(buf)
	if pattern == "" {
		return false
	}

	h.logger.Warn("Suspicious SIP traffic detected",
		zap.String("ip", host),
		zap.String("pattern", pattern),
		zap.ByteString("sample", buf[:min(64, len(buf))]),
	)
	if enableMetrics {
		RecordSuspiciousPattern(pattern)
		RecordConnection("suspicious")
	}
	if h.guardian.storage != nil {
		go h.guardian.storage.RecordSuspiciousPattern(host, pattern, string(buf[:min(200, len(buf))]))
	}

	if !h.guardian.RecordFailure(host, "suspicious_sip_pattern") {
		return false
	}
	h.logger.Warn("IP banned due to suspicious activity",
		zap.String("ip", host),
	)
	return true
}

// suspiciousPatternDefs lists scanner signatures and the names reported for
// them. Patterns are stored in lower case and matched case-insensitively.
// IMPORTANT: patterns must be specific enough to avoid false positives on
// legitimate traffic.
var suspiciousPatternDefs = []struct {
	name    string
	pattern []byte
}{
	{"sipvicious", []byte("sipvicious")},
	{"friendly-scanner", []byte("friendly-scanner")},
	{"sipcli", []byte("sipcli")},
	{"sip-scan", []byte("sip-scan")},
	{"voipbuster", []byte("voipbuster")},
	// "asterisk pbx" was removed: it matched legitimate user agents such as
	// "User-Agent: Asterisk PBX 18.0".
	{"sipsak", []byte("sipsak")},
	{"sundayddr", []byte("sundayddr")},
	{"iwar", []byte("iwar")},
	// "cseq: 1 options" was removed: the first OPTIONS request from any client
	// carries CSeq 1. OPTIONS floods are handled by rate limiting instead.
	// NOTE(review): extension 100 is a common real extension, so this pattern
	// may flag legitimate calls.
	{"test-extension-100", []byte("sip:100@")},
	{"test-extension-1000", []byte("sip:1000@")},
	{"null-user", []byte("sip:@")},
	// NOTE(review): RFC 3261 §8.1.1.3 recommends "sip:anonymous@anonymous.invalid"
	// as the From URI for anonymous calls, so this pattern may flag
	// privacy-enabled callers.
	{"anonymous", []byte("anonymous@")},
}

// detectSuspiciousPattern returns the name of the first entry in
// suspiciousPatternDefs whose pattern occurs in data (case-insensitively), or
// "" when none does.
func detectSuspiciousPattern(data []byte) string {
	for _, def := range suspiciousPatternDefs {
		if bytesContainsCI(data, def.pattern) {
			return def.name
		}
	}
	return ""
}

// bytesContainsCI reports whether needle occurs in haystack, comparing
// case-insensitively, without allocating. An empty needle always matches.
func bytesContainsCI(haystack, needle []byte) bool {
	if len(needle) == 0 {
		return true
	}
	if len(haystack) < len(needle) {
		return false
	}
	for i := 0; i <= len(haystack)-len(needle); i++ {
		if bytesEqualCI(haystack[i:i+len(needle)], needle) {
			return true
		}
	}
	return false
}

// bytesEqualCI compares two byte slices case-insensitively without
// allocating. Each byte is lowered as a single code point.
func bytesEqualCI(a, b []byte) bool {
	if len(a) != len(b) {
		return false
	}
	for i := range a {
		if unicode.ToLower(rune(a[i])) != unicode.ToLower(rune(b[i])) {
			return false
		}
	}
	return true
}

// isSuspiciousSIP reports whether data contains any known scanner signature.
// It is a legacy wrapper around detectSuspiciousPattern.
func isSuspiciousSIP(data []byte) bool {
	return detectSuspiciousPattern(data) != ""
}

// UnmarshalCaddyfile implements caddyfile.Unmarshaler for SIPMatcher.
// Usage in Caddyfile:
//
//	@sip sip {
//		methods REGISTER INVITE OPTIONS
//	}
//
// or, with the methods inline:
//
//	@sip sip REGISTER INVITE OPTIONS
//
// Or simply: @sip sip (uses the default method set)
func (m *SIPMatcher) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
	d.Next() // consume the matcher name ("sip")

	// Inline arguments were previously ignored silently; treat them as methods.
	if args := d.RemainingArgs(); len(args) > 0 {
		m.Methods = args
	}

	for nesting := d.Nesting(); d.NextBlock(nesting); {
		switch d.Val() {
		case "methods":
			m.Methods = d.RemainingArgs()
			if len(m.Methods) == 0 {
				return d.ArgErr()
			}
		default:
			return d.Errf("unknown sip matcher directive: %s", d.Val())
		}
	}
	return nil
}

// UnmarshalCaddyfile implements caddyfile.Unmarshaler for SIPHandler.
// Usage in Caddyfile: // // sip_guardian { // max_failures 5 // find_time 10m // ban_time 1h // whitelist 10.0.0.0/8 172.16.0.0/12 // storage /data/sip-guardian.db // geoip_db /data/GeoLite2-Country.mmdb // block_countries CN RU // allow_countries US CA GB // enumeration { // max_extensions 20 // extension_window 5m // sequential_threshold 5 // rapid_fire_count 10 // rapid_fire_window 30s // ban_time 2h // exempt_extensions 100 200 9999 // } // validation { // enabled true // mode permissive # permissive, strict, paranoid // max_message_size 65535 // ban_on_null_bytes true // ban_on_binary_injection true // disabled_rules via_invalid_branch cseq_out_of_range // } // webhook http://example.com/hook { ... } // } // // Or simply: sip_guardian (uses defaults) func (h *SIPHandler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { // Move past "sip_guardian" token d.Next() // Parse configuration into the embedded SIPGuardian struct // This config will be applied to the shared guardian during Provision for nesting := d.Nesting(); d.NextBlock(nesting); { switch d.Val() { case "max_failures": if !d.NextArg() { return d.ArgErr() } var val int if _, err := fmt.Sscanf(d.Val(), "%d", &val); err != nil { return d.Errf("invalid max_failures: %v", err) } h.MaxFailures = val case "find_time": if !d.NextArg() { return d.ArgErr() } dur, err := caddy.ParseDuration(d.Val()) if err != nil { return d.Errf("invalid find_time: %v", err) } h.FindTime = caddy.Duration(dur) case "ban_time": if !d.NextArg() { return d.ArgErr() } dur, err := caddy.ParseDuration(d.Val()) if err != nil { return d.Errf("invalid ban_time: %v", err) } h.BanTime = caddy.Duration(dur) case "whitelist": for d.NextArg() { h.WhitelistCIDR = append(h.WhitelistCIDR, d.Val()) } case "storage": if !d.NextArg() { return d.ArgErr() } h.StoragePath = d.Val() case "geoip_db": if !d.NextArg() { return d.ArgErr() } h.GeoIPPath = d.Val() case "block_countries": for d.NextArg() { country := d.Val() if expanded := 
ExpandContinentCode(country); expanded != nil { h.BlockedCountries = append(h.BlockedCountries, expanded...) } else { h.BlockedCountries = append(h.BlockedCountries, country) } } case "allow_countries": for d.NextArg() { country := d.Val() if expanded := ExpandContinentCode(country); expanded != nil { h.AllowedCountries = append(h.AllowedCountries, expanded...) } else { h.AllowedCountries = append(h.AllowedCountries, country) } } case "enumeration": h.Enumeration = &EnumerationConfig{} for innerNesting := d.Nesting(); d.NextBlock(innerNesting); { switch d.Val() { case "max_extensions": if !d.NextArg() { return d.ArgErr() } var val int if _, err := fmt.Sscanf(d.Val(), "%d", &val); err != nil { return d.Errf("invalid max_extensions: %v", err) } h.Enumeration.MaxExtensions = val case "extension_window": if !d.NextArg() { return d.ArgErr() } dur, err := caddy.ParseDuration(d.Val()) if err != nil { return d.Errf("invalid extension_window: %v", err) } h.Enumeration.ExtensionWindow = dur case "sequential_threshold": if !d.NextArg() { return d.ArgErr() } var val int if _, err := fmt.Sscanf(d.Val(), "%d", &val); err != nil { return d.Errf("invalid sequential_threshold: %v", err) } h.Enumeration.SequentialThreshold = val case "rapid_fire_count": if !d.NextArg() { return d.ArgErr() } var val int if _, err := fmt.Sscanf(d.Val(), "%d", &val); err != nil { return d.Errf("invalid rapid_fire_count: %v", err) } h.Enumeration.RapidFireCount = val case "rapid_fire_window": if !d.NextArg() { return d.ArgErr() } dur, err := caddy.ParseDuration(d.Val()) if err != nil { return d.Errf("invalid rapid_fire_window: %v", err) } h.Enumeration.RapidFireWindow = dur case "ban_time": if !d.NextArg() { return d.ArgErr() } dur, err := caddy.ParseDuration(d.Val()) if err != nil { return d.Errf("invalid enumeration ban_time: %v", err) } h.Enumeration.EnumBanTime = dur case "exempt_extensions": h.Enumeration.ExemptExtensions = d.RemainingArgs() default: return d.Errf("unknown enumeration directive: 
%s", d.Val()) } } case "validation": h.Validation = &ValidationConfig{} for innerNesting := d.Nesting(); d.NextBlock(innerNesting); { switch d.Val() { case "enabled": if !d.NextArg() { return d.ArgErr() } h.Validation.Enabled = d.Val() == "true" || d.Val() == "yes" || d.Val() == "on" case "mode": if !d.NextArg() { return d.ArgErr() } mode := ValidationMode(d.Val()) if mode != ValidationModePermissive && mode != ValidationModeStrict && mode != ValidationModeParanoid { return d.Errf("invalid validation mode: %s (must be permissive, strict, or paranoid)", d.Val()) } h.Validation.Mode = mode case "max_message_size": if !d.NextArg() { return d.ArgErr() } var val int if _, err := fmt.Sscanf(d.Val(), "%d", &val); err != nil { return d.Errf("invalid max_message_size: %v", err) } h.Validation.MaxMessageSize = val case "ban_on_null_bytes": if !d.NextArg() { return d.ArgErr() } h.Validation.BanOnNullBytes = d.Val() == "true" || d.Val() == "yes" || d.Val() == "on" case "ban_on_binary_injection": if !d.NextArg() { return d.ArgErr() } h.Validation.BanOnBinaryInjection = d.Val() == "true" || d.Val() == "yes" || d.Val() == "on" case "disabled_rules": h.Validation.DisabledRules = d.RemainingArgs() default: return d.Errf("unknown validation directive: %s", d.Val()) } } case "webhook": if !d.NextArg() { return d.ArgErr() } webhook := WebhookConfig{ URL: d.Val(), } // Parse webhook block if present for innerNesting := d.Nesting(); d.NextBlock(innerNesting); { switch d.Val() { case "events": webhook.Events = d.RemainingArgs() case "secret": if !d.NextArg() { return d.ArgErr() } webhook.Secret = d.Val() case "timeout": if !d.NextArg() { return d.ArgErr() } dur, err := caddy.ParseDuration(d.Val()) if err != nil { return d.Errf("invalid webhook timeout: %v", err) } webhook.Timeout = dur case "header": args := d.RemainingArgs() if len(args) != 2 { return d.Errf("header requires name and value") } if webhook.Headers == nil { webhook.Headers = make(map[string]string) } 
webhook.Headers[args[0]] = args[1] default: return d.Errf("unknown webhook directive: %s", d.Val()) } } h.Webhooks = append(h.Webhooks, webhook) default: return d.Errf("unknown sip_guardian directive: %s", d.Val()) } } return nil } // Interface guards var ( _ layer4.ConnMatcher = (*SIPMatcher)(nil) _ layer4.NextHandler = (*SIPHandler)(nil) _ caddy.Provisioner = (*SIPMatcher)(nil) _ caddy.Provisioner = (*SIPHandler)(nil) _ caddyfile.Unmarshaler = (*SIPMatcher)(nil) _ caddyfile.Unmarshaler = (*SIPHandler)(nil) )