package sipguardian

import (
	"bytes"
	"fmt"
	"io"
	"net"
	"regexp"
	"strings"

	"github.com/caddyserver/caddy/v2"
	"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
	"github.com/mholt/caddy-l4/layer4"
	"go.uber.org/zap"
)

func init() {
	caddy.RegisterModule(SIPMatcher{})
	caddy.RegisterModule(SIPHandler{})
}

// SIPMatcher matches SIP traffic by inspecting the first bytes of a
// connection: either a request line "<METHOD> sip:" / "<METHOD> sips:" for
// one of Methods, or a status line starting with "SIP/2.0".
type SIPMatcher struct {
	// Methods lists the SIP request methods to match (REGISTER, INVITE,
	// OPTIONS, ...). Matching is case-insensitive. When empty, Provision
	// fills in a default set of common methods.
	Methods []string `json:"methods,omitempty"`

	// methodRegex is compiled from Methods by Provision.
	methodRegex *regexp.Regexp
}

// CaddyModule returns the Caddy module information.
func (SIPMatcher) CaddyModule() caddy.ModuleInfo {
	return caddy.ModuleInfo{
		ID:  "layer4.matchers.sip",
		New: func() caddy.Module { return new(SIPMatcher) },
	}
}

// Provision fills in the default method list when none is configured and
// compiles the request-line pattern. It returns an error if a configured
// method name is empty.
func (m *SIPMatcher) Provision(ctx caddy.Context) error {
	if len(m.Methods) == 0 {
		// Default: match common SIP methods.
		m.Methods = []string{"REGISTER", "INVITE", "OPTIONS", "ACK", "BYE", "CANCEL", "INFO", "NOTIFY", "SUBSCRIBE", "MESSAGE"}
	}

	// Quote every method so configuration values are matched literally.
	// Unquoted, a value like "INVITE|.*" would widen the matcher, and an
	// invalid pattern like "(" would make regexp.MustCompile panic.
	alternatives := make([]string, 0, len(m.Methods))
	for _, method := range m.Methods {
		if method == "" {
			// An empty alternative would match any line starting with " sip:".
			return fmt.Errorf("sip matcher: empty method name")
		}
		alternatives = append(alternatives, regexp.QuoteMeta(method))
	}

	// Accept both sip: and sips: Request-URIs.
	pattern := "(?i)^(" + strings.Join(alternatives, "|") + ") sips?:"
	re, err := regexp.Compile(pattern)
	if err != nil {
		return fmt.Errorf("compiling sip method pattern: %w", err)
	}
	m.methodRegex = re
	return nil
}

// Match reports whether the connection appears to carry SIP traffic.
//
// It reads up to 64 bytes from cx. A short read is accepted only when
// io.ReadFull reports io.ErrUnexpectedEOF after at least 8 bytes, the length
// of the shortest recognizable starts ("ACK sip:", "SIP/2.0 "). Any other
// read error is returned unchanged. The intent is that caddy-l4 can fetch
// more data when the error is ErrConsumedAllPrefetchedBytes.
// On a match, the peeked bytes are stored in the "sip_peek" connection var.
func (m *SIPMatcher) Match(cx *layer4.Connection) (bool, error) {
	buf := make([]byte, 64)
	n, err := io.ReadFull(cx, buf)
	switch {
	case err == io.ErrUnexpectedEOF && n >= 8:
		// Fewer than 64 bytes, but enough to recognize a SIP start line.
		buf = buf[:n]
	case err != nil:
		return false, err
	}

	// Request line (REGISTER sip:..., INVITE sip:..., ...) or response
	// status line (SIP/2.0 200 OK, ...).
	if m.methodRegex.Match(buf) || bytes.HasPrefix(buf, []byte("SIP/2.0")) {
		cx.SetVar("sip_peek", buf)
		return true, nil
	}
	return false, nil
}

// SIPHandler is a Layer 4 handler that enforces SIP Guardian rules.
type SIPHandler struct {
	// GuardianRef names the guardian instance to share across handlers.
	// NOTE(review): Provision always requests the "default" guardian and
	// never reads this field; confirm whether that is intended.
	GuardianRef string `json:"guardian,omitempty"`

	// Upstream is the address to proxy to.
	// NOTE(review): not read by any code in this file; proxying is
	// presumably done by the next handler in the chain.
	Upstream string `json:"upstream,omitempty"`

	// SIPGuardian holds the guardian configuration parsed from the
	// Caddyfile/JSON. Provision passes it to GetOrCreateGuardianWithConfig
	// so it can be applied to the shared guardian.
	SIPGuardian

	logger   *zap.Logger
	guardian *SIPGuardian
}

// CaddyModule returns the Caddy module information.
func (SIPHandler) CaddyModule() caddy.ModuleInfo {
	return caddy.ModuleInfo{
		ID:  "layer4.handlers.sip_guardian",
		New: func() caddy.Module { return new(SIPHandler) },
	}
}

// Provision sets up the logger and obtains the shared guardian instance
// registered under "default". It passes the embedded config so the guardian
// can be configured from it.
func (h *SIPHandler) Provision(ctx caddy.Context) error {
	h.logger = ctx.Logger()

	guardian, err := GetOrCreateGuardianWithConfig(ctx, "default", &h.SIPGuardian)
	if err != nil {
		return err
	}
	h.guardian = guardian
	return nil
}

// Handle enforces SIP Guardian rules on cx and, if the connection is not
// rejected, passes it to next. Checks run in this order:
//
//  1. banned IP: close
//  2. whitelisted IP: skip the remaining checks and continue
//  3. GeoIP country block: close
//  4. inspection of the first read (up to 1024 bytes): per-method rate
//     limit, extension enumeration, suspicious patterns (see inspectPayload)
//
// A read error during inspection is logged and the connection continues.
//
// NOTE(review): outside of matching, reading from a caddy-l4 Connection
// appears to consume buffered bytes. If so, the bytes read here never reach
// `next` (e.g. a proxy handler). Confirm against the caddy-l4 version in
// use. If they are consumed, they must be replayed to the next handler.
func (h *SIPHandler) Handle(cx *layer4.Connection, next layer4.Handler) error {
	remoteAddr := cx.RemoteAddr().String()
	host, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		// Not in host:port form; key everything on the full address string.
		host = remoteAddr
	}

	// Banned IPs are dropped before the whitelist is consulted.
	if h.guardian.IsBanned(host) {
		h.logger.Debug("Blocked banned IP", zap.String("ip", host))
		if enableMetrics {
			RecordConnection("blocked")
		}
		return cx.Close()
	}

	// Whitelisted IPs skip GeoIP and payload inspection.
	if h.guardian.IsWhitelisted(host) {
		if enableMetrics {
			RecordConnection("allowed")
		}
		return next.Handle(cx)
	}

	if blocked, country := h.guardian.IsCountryBlocked(host); blocked {
		h.logger.Info("Blocked connection from blocked country",
			zap.String("ip", host),
			zap.String("country", country),
		)
		if enableMetrics {
			RecordConnection("geo_blocked")
		}
		return cx.Close()
	}

	buf := make([]byte, 1024)
	n, err := cx.Read(buf)
	if n > 0 {
		if h.inspectPayload(host, buf[:n]) {
			return cx.Close()
		}
	} else if err != nil {
		h.logger.Debug("Failed to read SIP data for inspection",
			zap.String("ip", host),
			zap.Error(err),
		)
	}

	if enableMetrics {
		RecordConnection("allowed")
	}
	return next.Handle(cx)
}

// inspectPayload runs the payload-based checks for host in order: rate
// limiting, enumeration detection, then suspicious-pattern detection. It
// stops at the first check that rejects. It returns true when the
// connection must be closed.
func (h *SIPHandler) inspectPayload(host string, payload []byte) bool {
	h.logger.Debug("Read SIP data for inspection",
		zap.String("ip", host),
		zap.Int("bytes", len(payload)),
	)
	return h.enforceRateLimit(host, payload) ||
		h.enforceEnumerationLimit(host, payload) ||
		h.enforceSuspiciousPatterns(host, payload)
}

// enforceRateLimit applies the per-method rate limiter to host. Payloads
// without a recognizable SIP method are not rate limited. When the limit is
// exceeded, it records a failure for host (which may trigger a ban) and
// returns true.
func (h *SIPHandler) enforceRateLimit(host string, payload []byte) bool {
	method := ExtractSIPMethod(payload)
	if method == "" {
		return false
	}

	allowed, reason := GetRateLimiter(h.logger).Allow(host, method)
	if allowed {
		return false
	}

	h.logger.Warn("Rate limit exceeded",
		zap.String("ip", host),
		zap.String("method", string(method)),
	)
	if enableMetrics {
		RecordConnection("rate_limited")
	}
	// Record as failure (may trigger ban).
	h.guardian.RecordFailure(host, reason)
	return true
}

// enforceEnumerationLimit feeds the target extension (if any) to the
// enumeration detector. On detection it logs, records metrics, stores the
// attempt (when storage is configured), emits a webhook event (when enabled),
// records a failure for host and returns true. Otherwise it refreshes the
// tracked-IP gauge (when metrics are enabled) and returns false.
func (h *SIPHandler) enforceEnumerationLimit(host string, payload []byte) bool {
	extension := ExtractTargetExtension(payload)
	if extension == "" {
		return false
	}

	detector := GetEnumerationDetector(h.logger)
	result := detector.RecordAttempt(host, extension)
	if !result.Detected {
		if enableMetrics {
			stats := detector.GetStats()
			if trackedIPs, ok := stats["tracked_ips"].(int); ok {
				UpdateEnumerationTrackedIPs(trackedIPs)
			}
		}
		return false
	}

	h.logger.Warn("Enumeration attack detected",
		zap.String("ip", host),
		zap.String("reason", result.Reason),
		zap.Int("unique_extensions", result.UniqueCount),
		zap.Strings("extensions", result.Extensions),
	)
	if enableMetrics {
		RecordEnumerationDetection(result.Reason)
		RecordEnumerationExtensions(result.UniqueCount)
		RecordConnection("enumeration_blocked")
	}
	if h.guardian.storage != nil {
		go h.guardian.storage.RecordEnumerationAttempt(host, result.Reason, result.UniqueCount, result.Extensions)
	}
	if enableWebhooks {
		go EmitEnumerationEvent(h.logger, host, result)
	}
	// The "enumeration_" prefix presumably lets the guardian apply an
	// enumeration-specific ban time; confirm in RecordFailure.
	h.guardian.RecordFailure(host, "enumeration_"+result.Reason)
	return true
}

// enforceSuspiciousPatterns checks payload against suspiciousPatternDefs.
// On a hit, it logs, records metrics, stores a sample (when storage is
// configured) and records a failure for host. It returns true only when that
// failure resulted in a ban. A suspicious but not-yet-banned connection
// continues.
func (h *SIPHandler) enforceSuspiciousPatterns(host string, payload []byte) bool {
	pattern := detectSuspiciousPattern(payload)
	if pattern == "" {
		return false
	}

	h.logger.Warn("Suspicious SIP traffic detected",
		zap.String("ip", host),
		zap.String("pattern", pattern),
		zap.ByteString("sample", payload[:min(64, len(payload))]),
	)
	if enableMetrics {
		RecordSuspiciousPattern(pattern)
		RecordConnection("suspicious")
	}
	if h.guardian.storage != nil {
		go h.guardian.storage.RecordSuspiciousPattern(host, pattern, string(payload[:min(200, len(payload))]))
	}

	if !h.guardian.RecordFailure(host, "suspicious_sip_pattern") {
		return false
	}
	h.logger.Warn("IP banned due to suspicious activity",
		zap.String("ip", host),
	)
	return true
}

// suspiciousPatternDefs maps a pattern name to a lowercase substring.
// detectSuspiciousPattern looks for that substring in the lowercased
// message. Order matters: the first match wins.
var suspiciousPatternDefs = []struct {
	name    string
	pattern string
}{
	{"sipvicious", "sipvicious"},
	{"friendly-scanner", "friendly-scanner"},
	{"sipcli", "sipcli"},
	{"sip-scan", "sip-scan"},
	{"voipbuster", "voipbuster"},
	{"asterisk-pbx-scanner", "asterisk pbx"},
	{"sipsak", "sipsak"},
	{"sundayddr", "sundayddr"},
	{"iwar", "iwar"},
	// Matches any single message containing "CSeq: 1 OPTIONS"
	// (case-insensitive); on its own this does not measure repetition.
	{"cseq-flood", "cseq: 1 options"},
	{"zoiper-spoof", "user-agent: zoiper"},
	{"test-extension-100", "sip:100@"},
	{"test-extension-1000", "sip:1000@"},
	{"null-user", "sip:@"},
	{"anonymous", "anonymous@"},
}

// detectSuspiciousPattern returns the name of the first entry in
// suspiciousPatternDefs whose substring occurs in data (compared
// case-insensitively), or "" if none does.
func detectSuspiciousPattern(data []byte) string {
	lower := strings.ToLower(string(data))
	for _, def := range suspiciousPatternDefs {
		if strings.Contains(lower, def.pattern) {
			return def.name
		}
	}
	return ""
}

// isSuspiciousSIP reports whether data matches any suspicious pattern.
// It is a legacy wrapper around detectSuspiciousPattern.
func isSuspiciousSIP(data []byte) bool {
	return detectSuspiciousPattern(data) != ""
}

// min returns the smaller of a and b.
func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}

// UnmarshalCaddyfile implements caddyfile.Unmarshaler for SIPMatcher.
// Usage in Caddyfile:
//
//	@sip sip {
//	    methods REGISTER INVITE OPTIONS
//	}
//
// Or simply: @sip sip
//
// Multiple methods lines accumulate. Inline arguments after "sip" are
// rejected.
func (m *SIPMatcher) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
	d.Next() // consume the "sip" token

	// The matcher accepts only a block, not inline arguments.
	if d.NextArg() {
		return d.ArgErr()
	}

	for nesting := d.Nesting(); d.NextBlock(nesting); {
		switch d.Val() {
		case "methods":
			methods := d.RemainingArgs()
			if len(methods) == 0 {
				return d.ArgErr()
			}
			m.Methods = append(m.Methods, methods...)
		default:
			return d.Errf("unknown sip matcher directive: %s", d.Val())
		}
	}
	return nil
}

// UnmarshalCaddyfile implements caddyfile.Unmarshaler for SIPHandler.
// Usage in Caddyfile:
//
//	sip_guardian {
//	    max_failures 5
//	    find_time 10m
//	    ban_time 1h
//	    whitelist 10.0.0.0/8 172.16.0.0/12
//	    storage /data/sip-guardian.db
//	    geoip_db /data/GeoLite2-Country.mmdb
//	    block_countries CN RU
//	    allow_countries US CA GB
//	    enumeration {
//	        max_extensions 20
//	        extension_window 5m
//	        sequential_threshold 5
//	        rapid_fire_count 10
//	        rapid_fire_window 30s
//	        ban_time 2h
//	        exempt_extensions 100 200 9999
//	    }
//	    webhook http://example.com/hook { ... }
} // } // // Or simply: sip_guardian (uses defaults) func (h *SIPHandler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { // Move past "sip_guardian" token d.Next() // Parse configuration into the embedded SIPGuardian struct // This config will be applied to the shared guardian during Provision for nesting := d.Nesting(); d.NextBlock(nesting); { switch d.Val() { case "max_failures": if !d.NextArg() { return d.ArgErr() } var val int if _, err := fmt.Sscanf(d.Val(), "%d", &val); err != nil { return d.Errf("invalid max_failures: %v", err) } h.MaxFailures = val case "find_time": if !d.NextArg() { return d.ArgErr() } dur, err := caddy.ParseDuration(d.Val()) if err != nil { return d.Errf("invalid find_time: %v", err) } h.FindTime = caddy.Duration(dur) case "ban_time": if !d.NextArg() { return d.ArgErr() } dur, err := caddy.ParseDuration(d.Val()) if err != nil { return d.Errf("invalid ban_time: %v", err) } h.BanTime = caddy.Duration(dur) case "whitelist": for d.NextArg() { h.WhitelistCIDR = append(h.WhitelistCIDR, d.Val()) } case "storage": if !d.NextArg() { return d.ArgErr() } h.StoragePath = d.Val() case "geoip_db": if !d.NextArg() { return d.ArgErr() } h.GeoIPPath = d.Val() case "block_countries": for d.NextArg() { country := d.Val() if expanded := ExpandContinentCode(country); expanded != nil { h.BlockedCountries = append(h.BlockedCountries, expanded...) } else { h.BlockedCountries = append(h.BlockedCountries, country) } } case "allow_countries": for d.NextArg() { country := d.Val() if expanded := ExpandContinentCode(country); expanded != nil { h.AllowedCountries = append(h.AllowedCountries, expanded...) 
} else { h.AllowedCountries = append(h.AllowedCountries, country) } } case "enumeration": h.Enumeration = &EnumerationConfig{} for innerNesting := d.Nesting(); d.NextBlock(innerNesting); { switch d.Val() { case "max_extensions": if !d.NextArg() { return d.ArgErr() } var val int if _, err := fmt.Sscanf(d.Val(), "%d", &val); err != nil { return d.Errf("invalid max_extensions: %v", err) } h.Enumeration.MaxExtensions = val case "extension_window": if !d.NextArg() { return d.ArgErr() } dur, err := caddy.ParseDuration(d.Val()) if err != nil { return d.Errf("invalid extension_window: %v", err) } h.Enumeration.ExtensionWindow = dur case "sequential_threshold": if !d.NextArg() { return d.ArgErr() } var val int if _, err := fmt.Sscanf(d.Val(), "%d", &val); err != nil { return d.Errf("invalid sequential_threshold: %v", err) } h.Enumeration.SequentialThreshold = val case "rapid_fire_count": if !d.NextArg() { return d.ArgErr() } var val int if _, err := fmt.Sscanf(d.Val(), "%d", &val); err != nil { return d.Errf("invalid rapid_fire_count: %v", err) } h.Enumeration.RapidFireCount = val case "rapid_fire_window": if !d.NextArg() { return d.ArgErr() } dur, err := caddy.ParseDuration(d.Val()) if err != nil { return d.Errf("invalid rapid_fire_window: %v", err) } h.Enumeration.RapidFireWindow = dur case "ban_time": if !d.NextArg() { return d.ArgErr() } dur, err := caddy.ParseDuration(d.Val()) if err != nil { return d.Errf("invalid enumeration ban_time: %v", err) } h.Enumeration.EnumBanTime = dur case "exempt_extensions": h.Enumeration.ExemptExtensions = d.RemainingArgs() default: return d.Errf("unknown enumeration directive: %s", d.Val()) } } case "webhook": if !d.NextArg() { return d.ArgErr() } webhook := WebhookConfig{ URL: d.Val(), } // Parse webhook block if present for innerNesting := d.Nesting(); d.NextBlock(innerNesting); { switch d.Val() { case "events": webhook.Events = d.RemainingArgs() case "secret": if !d.NextArg() { return d.ArgErr() } webhook.Secret = d.Val() case 
"timeout": if !d.NextArg() { return d.ArgErr() } dur, err := caddy.ParseDuration(d.Val()) if err != nil { return d.Errf("invalid webhook timeout: %v", err) } webhook.Timeout = dur case "header": args := d.RemainingArgs() if len(args) != 2 { return d.Errf("header requires name and value") } if webhook.Headers == nil { webhook.Headers = make(map[string]string) } webhook.Headers[args[0]] = args[1] default: return d.Errf("unknown webhook directive: %s", d.Val()) } } h.Webhooks = append(h.Webhooks, webhook) default: return d.Errf("unknown sip_guardian directive: %s", d.Val()) } } return nil } // Interface guards var ( _ layer4.ConnMatcher = (*SIPMatcher)(nil) _ layer4.NextHandler = (*SIPHandler)(nil) _ caddy.Provisioner = (*SIPMatcher)(nil) _ caddy.Provisioner = (*SIPHandler)(nil) _ caddyfile.Unmarshaler = (*SIPMatcher)(nil) _ caddyfile.Unmarshaler = (*SIPHandler)(nil) )