caddy-sip-guardian/l4handler.go
Ryan Malloy 976fdf53a5 Add SIP message validation feature
Implements RFC 3261 compliance checking and security validation:

- Three validation modes: permissive (default), strict, paranoid
- Critical checks: null bytes, binary injection (immediate ban)
- RFC compliance: required headers (Via, From, To, Call-ID, CSeq, Max-Forwards)
- Format validation: CSeq range, Content-Length, Via branch format
- Paranoid mode: SQL injection patterns, excessive headers, long values
- Compact header form support (v, f, t, i, l, etc.)

Caddyfile configuration:
  validation {
      enabled true
      mode permissive
      max_message_size 65535
      ban_on_null_bytes true
      ban_on_binary_injection true
      disabled_rules via_invalid_branch
  }

New Prometheus metrics:
- sip_guardian_validation_violations_total{rule}
- sip_guardian_validation_results_total{result}
- sip_guardian_message_size_bytes (histogram)

Includes comprehensive unit tests covering all validation scenarios.
2025-12-07 15:57:26 -07:00

675 lines
18 KiB
Go

package sipguardian
import (
"bytes"
"fmt"
"io"
"net"
"regexp"
"strings"
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
"github.com/mholt/caddy-l4/layer4"
"go.uber.org/zap"
)
// init registers this file's layer4 modules with Caddy: the SIP matcher
// first, then the SIP Guardian handler.
func init() {
	modules := []caddy.Module{
		SIPMatcher{},
		SIPHandler{},
	}
	for _, mod := range modules {
		caddy.RegisterModule(mod)
	}
}
// SIPMatcher matches SIP traffic by inspecting the first bytes of a
// connection. It matches either a request line whose method is one of
// Methods, or a response that starts with "SIP/2.0".
type SIPMatcher struct {
	// Methods lists the SIP request methods to match (REGISTER, INVITE,
	// OPTIONS, etc.). Matching is case-insensitive. When empty, Provision
	// fills in a default set of common methods.
	Methods []string `json:"methods,omitempty"`
	// methodRegex is compiled from Methods during Provision and used by Match.
	methodRegex *regexp.Regexp
}
// CaddyModule returns the Caddy module information for the SIP matcher,
// registered under the layer4 matcher namespace as "sip".
func (SIPMatcher) CaddyModule() caddy.ModuleInfo {
	info := caddy.ModuleInfo{ID: "layer4.matchers.sip"}
	info.New = func() caddy.Module { return &SIPMatcher{} }
	return info
}
// Provision prepares the matcher for use by compiling the request-line
// pattern consulted by Match.
//
// When no methods are configured, a default set of common SIP methods is
// used. Each configured method is escaped with regexp.QuoteMeta so that it
// is matched literally. As a result, user-supplied values cannot inject
// regex syntax. A compile failure is returned as an error instead of
// panicking during provisioning.
//
// The compiled pattern is case-insensitive and anchored at the start of the
// data. It accepts either a "sip:" or a "sips:" Request-URI after the
// method and a single space, e.g. "INVITE sip:alice@example.com".
func (m *SIPMatcher) Provision(ctx caddy.Context) error {
	if len(m.Methods) == 0 {
		// Default: match common SIP methods
		m.Methods = []string{"REGISTER", "INVITE", "OPTIONS", "ACK", "BYE", "CANCEL", "INFO", "NOTIFY", "SUBSCRIBE", "MESSAGE"}
	}

	quoted := make([]string, 0, len(m.Methods))
	for _, method := range m.Methods {
		quoted = append(quoted, regexp.QuoteMeta(method))
	}

	pattern := "(?i)^(" + strings.Join(quoted, "|") + ") sips?:"
	re, err := regexp.Compile(pattern)
	if err != nil {
		return fmt.Errorf("compiling SIP method pattern %q: %w", pattern, err)
	}
	m.methodRegex = re
	return nil
}
// Match reports whether the connection appears to carry SIP traffic.
//
// It reads up to 64 bytes from cx. That is enough to see a request line
// such as "INVITE sip:..." or a status line such as "SIP/2.0 200 OK".
// Shorter data is accepted only when io.ReadFull reports
// io.ErrUnexpectedEOF and at least 8 bytes arrived. Any other read error is
// returned unchanged, so caddy-l4 can decide whether to fetch more data
// and retry.
//
// On a match, the inspected bytes are stored in the connection variable
// "sip_peek".
func (m *SIPMatcher) Match(cx *layer4.Connection) (bool, error) {
	peek := make([]byte, 64)
	got, err := io.ReadFull(cx, peek)
	switch {
	case err == nil:
		// The full 64-byte window was filled.
	case err == io.ErrUnexpectedEOF && got >= 8:
		// Short read, but enough bytes to classify the traffic.
		peek = peek[:got]
	default:
		return false, err
	}

	isRequest := m.methodRegex.Match(peek)
	if !isRequest && !bytes.HasPrefix(peek, []byte("SIP/2.0")) {
		return false, nil
	}
	cx.SetVar("sip_peek", peek)
	return true, nil
}
// SIPHandler is a Layer 4 handler that enforces SIP Guardian rules on a
// connection before handing it to the next handler. The rules are the ban
// list, whitelist, GeoIP blocking, message validation, rate limiting,
// enumeration detection and suspicious-pattern detection.
type SIPHandler struct {
	// GuardianRef is a reference to the guardian instance shared across
	// handlers (JSON key "guardian").
	GuardianRef string `json:"guardian,omitempty"`
	// Upstream is the address to proxy to.
	// NOTE(review): nothing in this file reads Upstream. Confirm where it is
	// consumed, or whether it is vestigial.
	Upstream string `json:"upstream,omitempty"`
	// SIPGuardian holds the guardian configuration parsed by
	// UnmarshalCaddyfile. Provision passes a pointer to it when obtaining
	// the shared guardian instance.
	SIPGuardian
	// logger is taken from the Caddy context in Provision.
	logger *zap.Logger
	// guardian is the shared instance obtained in Provision. Handle
	// delegates ban, whitelist, GeoIP and failure bookkeeping to it.
	guardian *SIPGuardian
}
// CaddyModule returns the Caddy module information for the SIP Guardian
// handler, registered in the layer4 handler namespace as "sip_guardian".
func (SIPHandler) CaddyModule() caddy.ModuleInfo {
	info := caddy.ModuleInfo{ID: "layer4.handlers.sip_guardian"}
	info.New = func() caddy.Module { return &SIPHandler{} }
	return info
}
// Provision sets up the handler's logger and attaches it to a shared
// guardian instance.
//
// The instance is looked up (or created) by name via
// GetOrCreateGuardianWithConfig. The embedded SIPGuardian parsed from the
// configuration is passed along as its config. The name is GuardianRef
// (JSON key "guardian") when set, so the documented option takes effect.
// Otherwise it is "default", and handlers without an explicit reference
// share the default instance.
func (h *SIPHandler) Provision(ctx caddy.Context) error {
	h.logger = ctx.Logger()

	name := h.GuardianRef
	if name == "" {
		name = "default"
	}

	guardian, err := GetOrCreateGuardianWithConfig(ctx, name, &h.SIPGuardian)
	if err != nil {
		return fmt.Errorf("provisioning sip_guardian %q: %w", name, err)
	}
	h.guardian = guardian
	return nil
}
// Handle processes the connection with SIP-aware protection.
//
// The checks run in this order. The first one that rejects the connection
// returns cx.Close() without calling next:
//  1. The IP is banned: close.
//  2. The IP is whitelisted: skip every remaining check and call next.
//  3. The IP's country is blocked (GeoIP): close.
//  4. A single cx.Read of up to 4096 bytes. If any bytes were read:
//     a. SIP validation (closes only when the result has ShouldBan).
//     b. Per-method rate limiting (closes when over the limit).
//     c. Extension enumeration detection (closes when detected).
//     d. Suspicious-pattern detection (closes only if RecordFailure
//        reports a ban).
//
// Connections that get through are counted as "allowed" and passed to
// next. Storage writes and webhook events are started in goroutines that
// this function does not wait for.
func (h *SIPHandler) Handle(cx *layer4.Connection, next layer4.Handler) error {
	// Key every check on the bare IP. Fall back to the full address string
	// when it cannot be split into host and port.
	remoteAddr := cx.RemoteAddr().String()
	host, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		host = remoteAddr
	}
	// Check if IP is banned
	if h.guardian.IsBanned(host) {
		h.logger.Debug("Blocked banned IP", zap.String("ip", host))
		if enableMetrics {
			RecordConnection("blocked")
		}
		return cx.Close()
	}
	// Whitelisted IPs bypass everything below: GeoIP, validation, rate
	// limiting, enumeration and pattern detection.
	if h.guardian.IsWhitelisted(host) {
		if enableMetrics {
			RecordConnection("allowed")
		}
		return next.Handle(cx)
	}
	// Check GeoIP blocking (if configured)
	if blocked, country := h.guardian.IsCountryBlocked(host); blocked {
		h.logger.Info("Blocked connection from blocked country",
			zap.String("ip", host),
			zap.String("country", country),
		)
		if enableMetrics {
			RecordConnection("geo_blocked")
		}
		return cx.Close()
	}
	// Read data from the connection for inspection. This is a single Read
	// of at most 4096 bytes, so it may return less than a complete SIP
	// message.
	// NOTE(review): confirm against the caddy-l4 version in use whether
	// this Read consumes the replayed prefetched bytes. If it does, next
	// (e.g. a proxy handler) never receives the bytes inspected below, and
	// they would need to be re-injected before calling next.Handle.
	buf := make([]byte, 4096) // Larger buffer for validation
	n, err := cx.Read(buf)
	if n > 0 {
		// Inspect only the bytes actually read. A non-nil err returned
		// together with n > 0 is not examined on this path.
		buf = buf[:n]
		h.logger.Debug("Read SIP data for inspection",
			zap.String("ip", host),
			zap.Int("bytes", n),
		)
		// Record message size metric
		if enableMetrics {
			RecordMessageSize(n)
		}
		// Validate SIP message structure and content
		validator := GetValidator(h.logger)
		if validator.IsEnabled() {
			validationResult := validator.Validate(buf)
			// Record one metric per violation plus one overall outcome
			// (valid / ban / invalid).
			if enableMetrics {
				for _, v := range validationResult.Violations {
					RecordValidationViolation(v.Rule)
				}
				if validationResult.Valid {
					RecordValidationResult("valid")
				} else if validationResult.ShouldBan {
					RecordValidationResult("ban")
				} else {
					RecordValidationResult("invalid")
				}
			}
			if !validationResult.Valid {
				h.logger.Warn("SIP validation failed",
					zap.String("ip", host),
					zap.Int("violation_count", len(validationResult.Violations)),
					zap.Bool("should_ban", validationResult.ShouldBan),
				)
				// Log individual violations at debug level
				for _, v := range validationResult.Violations {
					h.logger.Debug("Validation violation",
						zap.String("ip", host),
						zap.String("rule", v.Rule),
						zap.String("severity", string(v.Severity)),
						zap.String("message", v.Message),
					)
				}
				if validationResult.ShouldBan {
					if enableMetrics {
						RecordConnection("validation_blocked")
					}
					h.guardian.RecordFailure(host, validationResult.BanReason)
					return cx.Close()
				}
				// Invalid messages that do not warrant a ban fall through to
				// the remaining checks, whatever the validation mode.
				// RecordFailure is only called on the ShouldBan path above.
				// NOTE(review): this handler does not itself reject
				// non-banning violations in strict/paranoid mode. Confirm
				// whether the validator is expected to express that via
				// ShouldBan.
			}
		}
		// Extract SIP method for rate limiting
		method := ExtractSIPMethod(buf)
		if method != "" {
			// Check rate limit
			rl := GetRateLimiter(h.logger)
			if allowed, reason := rl.Allow(host, method); !allowed {
				h.logger.Warn("Rate limit exceeded",
					zap.String("ip", host),
					zap.String("method", string(method)),
				)
				if enableMetrics {
					RecordConnection("rate_limited")
				}
				// Record as failure (may trigger ban)
				h.guardian.RecordFailure(host, reason)
				return cx.Close()
			}
		}
		// Check for extension enumeration attacks
		extension := ExtractTargetExtension(buf)
		if extension != "" {
			detector := GetEnumerationDetector(h.logger)
			result := detector.RecordAttempt(host, extension)
			if result.Detected {
				h.logger.Warn("Enumeration attack detected",
					zap.String("ip", host),
					zap.String("reason", result.Reason),
					zap.Int("unique_extensions", result.UniqueCount),
					zap.Strings("extensions", result.Extensions),
				)
				if enableMetrics {
					RecordEnumerationDetection(result.Reason)
					RecordEnumerationExtensions(result.UniqueCount)
					RecordConnection("enumeration_blocked")
				}
				// Store in persistent storage if enabled (fire-and-forget)
				if h.guardian.storage != nil {
					go h.guardian.storage.RecordEnumerationAttempt(host, result.Reason, result.UniqueCount, result.Extensions)
				}
				// Emit webhook event (fire-and-forget)
				if enableWebhooks {
					go EmitEnumerationEvent(h.logger, host, result)
				}
				// Record a failure tagged "enumeration_<reason>" and close.
				// NOTE(review): whether this bans immediately, and with which
				// ban time, is decided inside RecordFailure (not shown here).
				h.guardian.RecordFailure(host, "enumeration_"+result.Reason)
				return cx.Close()
			}
			// Update metrics for tracked IPs
			if enableMetrics {
				stats := detector.GetStats()
				if trackedIPs, ok := stats["tracked_ips"].(int); ok {
					UpdateEnumerationTrackedIPs(trackedIPs)
				}
			}
		}
		// Check for suspicious patterns in the SIP message
		suspiciousPattern := detectSuspiciousPattern(buf)
		if suspiciousPattern != "" {
			h.logger.Warn("Suspicious SIP traffic detected",
				zap.String("ip", host),
				zap.String("pattern", suspiciousPattern),
				zap.ByteString("sample", buf[:min(64, len(buf))]),
			)
			if enableMetrics {
				RecordSuspiciousPattern(suspiciousPattern)
				RecordConnection("suspicious")
			}
			// Store in persistent storage if enabled (fire-and-forget)
			if h.guardian.storage != nil {
				go h.guardian.storage.RecordSuspiciousPattern(host, suspiciousPattern, string(buf[:min(200, len(buf))]))
			}
			banned := h.guardian.RecordFailure(host, "suspicious_sip_pattern")
			if banned {
				h.logger.Warn("IP banned due to suspicious activity",
					zap.String("ip", host),
				)
				return cx.Close()
			}
			// Not banned yet: the connection continues to next and is also
			// counted as "allowed" below, in addition to "suspicious".
		}
	} else if err != nil {
		// No data and a read error: log it, but still pass the connection on.
		h.logger.Debug("Failed to read SIP data for inspection",
			zap.String("ip", host),
			zap.Error(err),
		)
	}
	// Count the connection as allowed. This also covers
	// suspicious-but-not-banned traffic and reads that returned no data.
	if enableMetrics {
		RecordConnection("allowed")
	}
	// Continue to next handler
	return next.Handle(cx)
}
// suspiciousPatternDefs is the ordered list of known-bad substrings checked
// by detectSuspiciousPattern. Each entry pairs a reported name with a
// lowercase pattern. Order matters, because the first match wins.
var suspiciousPatternDefs = []struct {
	name    string
	pattern string
}{
	{name: "sipvicious", pattern: "sipvicious"},
	{name: "friendly-scanner", pattern: "friendly-scanner"},
	{name: "sipcli", pattern: "sipcli"},
	{name: "sip-scan", pattern: "sip-scan"},
	{name: "voipbuster", pattern: "voipbuster"},
	{name: "asterisk-pbx-scanner", pattern: "asterisk pbx"},
	{name: "sipsak", pattern: "sipsak"},
	{name: "sundayddr", pattern: "sundayddr"},
	{name: "iwar", pattern: "iwar"},
	{name: "cseq-flood", pattern: "cseq: 1 options"}, // Repeated OPTIONS flood
	{name: "zoiper-spoof", pattern: "user-agent: zoiper"},
	{name: "test-extension-100", pattern: "sip:100@"},
	{name: "test-extension-1000", pattern: "sip:1000@"},
	{name: "null-user", pattern: "sip:@"},
	{name: "anonymous", pattern: "anonymous@"},
}

// detectSuspiciousPattern lowercases data and returns the name of the first
// entry in suspiciousPatternDefs whose pattern occurs in it. It returns ""
// when nothing matches.
func detectSuspiciousPattern(data []byte) string {
	normalized := strings.ToLower(string(data))
	for _, candidate := range suspiciousPatternDefs {
		if !strings.Contains(normalized, candidate.pattern) {
			continue
		}
		return candidate.name
	}
	return ""
}

// isSuspiciousSIP reports whether data contains any known attack pattern.
// It is a legacy boolean wrapper around detectSuspiciousPattern.
func isSuspiciousSIP(data []byte) bool {
	name := detectSuspiciousPattern(data)
	return len(name) > 0
}
// min returns the smaller of two ints. It is kept as a package helper so
// the file builds on Go versions that predate the builtin min.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
// UnmarshalCaddyfile implements caddyfile.Unmarshaler for SIPMatcher.
// Usage in Caddyfile:
//
//	@sip sip {
//		methods REGISTER INVITE OPTIONS
//	}
//
// Methods may also be listed inline on the matcher line:
//
//	@sip sip REGISTER INVITE OPTIONS
//
// Or simply: @sip sip. In that case Provision supplies the default methods.
//
// Inline methods and every `methods` line inside the block are accumulated
// into Methods. Inline arguments are not silently discarded.
func (m *SIPMatcher) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
	// Move past "sip" token
	d.Next()

	// Inline form: any arguments on the matcher line are method names.
	m.Methods = append(m.Methods, d.RemainingArgs()...)

	// Block form.
	for nesting := d.Nesting(); d.NextBlock(nesting); {
		switch d.Val() {
		case "methods":
			methods := d.RemainingArgs()
			if len(methods) == 0 {
				return d.ArgErr()
			}
			m.Methods = append(m.Methods, methods...)
		default:
			return d.Errf("unknown sip matcher directive: %s", d.Val())
		}
	}
	return nil
}
// UnmarshalCaddyfile implements caddyfile.Unmarshaler for SIPHandler.
// Usage in Caddyfile:
//
//	sip_guardian {
//		max_failures 5
//		find_time 10m
//		ban_time 1h
//		whitelist 10.0.0.0/8 172.16.0.0/12
//		storage /data/sip-guardian.db
//		geoip_db /data/GeoLite2-Country.mmdb
//		block_countries CN RU
//		allow_countries US CA GB
//		enumeration {
//			max_extensions 20
//			extension_window 5m
//			sequential_threshold 5
//			rapid_fire_count 10
//			rapid_fire_window 30s
//			ban_time 2h
//			exempt_extensions 100 200 9999
//		}
//		validation {
//			enabled true
//			mode permissive # permissive, strict, paranoid
//			max_message_size 65535
//			ban_on_null_bytes true
//			ban_on_binary_injection true
//			disabled_rules via_invalid_branch cseq_out_of_range
//		}
//		webhook http://example.com/hook { ... }
//	}
//
// Or simply: sip_guardian (uses defaults)
//
// The parsed values go into the embedded SIPGuardian. Provision later hands
// that config to the shared guardian. Country arguments are passed through
// ExpandContinentCode, so continent codes expand to their member countries.
//
// Boolean options accept true/yes/on/1 or false/no/off/0, case-insensitively.
// Any other value is a configuration error instead of being read as false.
func (h *SIPHandler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
	// Move past "sip_guardian" token
	d.Next()

	// appendCountries consumes the remaining arguments on the line and
	// appends each one to dst. A code that ExpandContinentCode recognizes is
	// replaced by its expansion.
	appendCountries := func(dst []string) []string {
		for d.NextArg() {
			code := d.Val()
			if expanded := ExpandContinentCode(code); expanded != nil {
				dst = append(dst, expanded...)
			} else {
				dst = append(dst, code)
			}
		}
		return dst
	}

	for nesting := d.Nesting(); d.NextBlock(nesting); {
		switch d.Val() {
		case "max_failures":
			if !d.NextArg() {
				return d.ArgErr()
			}
			var val int
			if _, err := fmt.Sscanf(d.Val(), "%d", &val); err != nil {
				return d.Errf("invalid max_failures: %v", err)
			}
			h.MaxFailures = val
		case "find_time":
			if !d.NextArg() {
				return d.ArgErr()
			}
			dur, err := caddy.ParseDuration(d.Val())
			if err != nil {
				return d.Errf("invalid find_time: %v", err)
			}
			h.FindTime = caddy.Duration(dur)
		case "ban_time":
			if !d.NextArg() {
				return d.ArgErr()
			}
			dur, err := caddy.ParseDuration(d.Val())
			if err != nil {
				return d.Errf("invalid ban_time: %v", err)
			}
			h.BanTime = caddy.Duration(dur)
		case "whitelist":
			for d.NextArg() {
				h.WhitelistCIDR = append(h.WhitelistCIDR, d.Val())
			}
		case "storage":
			if !d.NextArg() {
				return d.ArgErr()
			}
			h.StoragePath = d.Val()
		case "geoip_db":
			if !d.NextArg() {
				return d.ArgErr()
			}
			h.GeoIPPath = d.Val()
		case "block_countries":
			h.BlockedCountries = appendCountries(h.BlockedCountries)
		case "allow_countries":
			h.AllowedCountries = appendCountries(h.AllowedCountries)
		case "enumeration":
			if err := h.unmarshalEnumeration(d); err != nil {
				return err
			}
		case "validation":
			if err := h.unmarshalValidation(d); err != nil {
				return err
			}
		case "webhook":
			if err := h.unmarshalWebhook(d); err != nil {
				return err
			}
		default:
			return d.Errf("unknown sip_guardian directive: %s", d.Val())
		}
	}
	return nil
}

// unmarshalEnumeration parses an `enumeration { ... }` sub-block into a new
// h.Enumeration. The dispenser must be positioned on the "enumeration" token.
func (h *SIPHandler) unmarshalEnumeration(d *caddyfile.Dispenser) error {
	cfg := &EnumerationConfig{}
	h.Enumeration = cfg
	for nesting := d.Nesting(); d.NextBlock(nesting); {
		switch d.Val() {
		case "max_extensions":
			if !d.NextArg() {
				return d.ArgErr()
			}
			var val int
			if _, err := fmt.Sscanf(d.Val(), "%d", &val); err != nil {
				return d.Errf("invalid max_extensions: %v", err)
			}
			cfg.MaxExtensions = val
		case "extension_window":
			if !d.NextArg() {
				return d.ArgErr()
			}
			dur, err := caddy.ParseDuration(d.Val())
			if err != nil {
				return d.Errf("invalid extension_window: %v", err)
			}
			cfg.ExtensionWindow = dur
		case "sequential_threshold":
			if !d.NextArg() {
				return d.ArgErr()
			}
			var val int
			if _, err := fmt.Sscanf(d.Val(), "%d", &val); err != nil {
				return d.Errf("invalid sequential_threshold: %v", err)
			}
			cfg.SequentialThreshold = val
		case "rapid_fire_count":
			if !d.NextArg() {
				return d.ArgErr()
			}
			var val int
			if _, err := fmt.Sscanf(d.Val(), "%d", &val); err != nil {
				return d.Errf("invalid rapid_fire_count: %v", err)
			}
			cfg.RapidFireCount = val
		case "rapid_fire_window":
			if !d.NextArg() {
				return d.ArgErr()
			}
			dur, err := caddy.ParseDuration(d.Val())
			if err != nil {
				return d.Errf("invalid rapid_fire_window: %v", err)
			}
			cfg.RapidFireWindow = dur
		case "ban_time":
			if !d.NextArg() {
				return d.ArgErr()
			}
			dur, err := caddy.ParseDuration(d.Val())
			if err != nil {
				return d.Errf("invalid enumeration ban_time: %v", err)
			}
			cfg.EnumBanTime = dur
		case "exempt_extensions":
			cfg.ExemptExtensions = d.RemainingArgs()
		default:
			return d.Errf("unknown enumeration directive: %s", d.Val())
		}
	}
	return nil
}

// unmarshalValidation parses a `validation { ... }` sub-block into a new
// h.Validation. The dispenser must be positioned on the "validation" token.
// NOTE(review): Enabled stays false unless `enabled` is given explicitly.
// Confirm whether a bare validation block is meant to enable validation.
func (h *SIPHandler) unmarshalValidation(d *caddyfile.Dispenser) error {
	cfg := &ValidationConfig{}
	h.Validation = cfg

	// boolArg reads the next argument as a boolean flag. Unknown values are
	// rejected, so a typo such as "ture" cannot silently disable a check.
	boolArg := func(directive string) (bool, error) {
		if !d.NextArg() {
			return false, d.ArgErr()
		}
		switch strings.ToLower(d.Val()) {
		case "true", "yes", "on", "1":
			return true, nil
		case "false", "no", "off", "0":
			return false, nil
		}
		return false, d.Errf("invalid %s value %q (expected true/false, yes/no, on/off or 1/0)", directive, d.Val())
	}

	for nesting := d.Nesting(); d.NextBlock(nesting); {
		switch d.Val() {
		case "enabled":
			enabled, err := boolArg("enabled")
			if err != nil {
				return err
			}
			cfg.Enabled = enabled
		case "mode":
			if !d.NextArg() {
				return d.ArgErr()
			}
			mode := ValidationMode(d.Val())
			if mode != ValidationModePermissive && mode != ValidationModeStrict && mode != ValidationModeParanoid {
				return d.Errf("invalid validation mode: %s (must be permissive, strict, or paranoid)", d.Val())
			}
			cfg.Mode = mode
		case "max_message_size":
			if !d.NextArg() {
				return d.ArgErr()
			}
			var val int
			if _, err := fmt.Sscanf(d.Val(), "%d", &val); err != nil {
				return d.Errf("invalid max_message_size: %v", err)
			}
			cfg.MaxMessageSize = val
		case "ban_on_null_bytes":
			ban, err := boolArg("ban_on_null_bytes")
			if err != nil {
				return err
			}
			cfg.BanOnNullBytes = ban
		case "ban_on_binary_injection":
			ban, err := boolArg("ban_on_binary_injection")
			if err != nil {
				return err
			}
			cfg.BanOnBinaryInjection = ban
		case "disabled_rules":
			cfg.DisabledRules = d.RemainingArgs()
		default:
			return d.Errf("unknown validation directive: %s", d.Val())
		}
	}
	return nil
}

// unmarshalWebhook parses `webhook <url> [{ ... }]` and appends the result
// to h.Webhooks. The dispenser must be positioned on the "webhook" token.
func (h *SIPHandler) unmarshalWebhook(d *caddyfile.Dispenser) error {
	if !d.NextArg() {
		return d.ArgErr()
	}
	webhook := WebhookConfig{
		URL: d.Val(),
	}
	// Parse webhook block if present
	for nesting := d.Nesting(); d.NextBlock(nesting); {
		switch d.Val() {
		case "events":
			webhook.Events = d.RemainingArgs()
		case "secret":
			if !d.NextArg() {
				return d.ArgErr()
			}
			webhook.Secret = d.Val()
		case "timeout":
			if !d.NextArg() {
				return d.ArgErr()
			}
			dur, err := caddy.ParseDuration(d.Val())
			if err != nil {
				return d.Errf("invalid webhook timeout: %v", err)
			}
			webhook.Timeout = dur
		case "header":
			args := d.RemainingArgs()
			if len(args) != 2 {
				return d.Errf("header requires name and value")
			}
			if webhook.Headers == nil {
				webhook.Headers = make(map[string]string)
			}
			webhook.Headers[args[0]] = args[1]
		default:
			return d.Errf("unknown webhook directive: %s", d.Val())
		}
	}
	h.Webhooks = append(h.Webhooks, webhook)
	return nil
}
// Interface guards: compile-time assertions that both modules implement the
// interfaces that Caddy and caddy-l4 expect of them.
var (
	// SIPMatcher is a layer4 connection matcher that can be provisioned and
	// configured from a Caddyfile.
	_ layer4.ConnMatcher    = (*SIPMatcher)(nil)
	_ caddy.Provisioner     = (*SIPMatcher)(nil)
	_ caddyfile.Unmarshaler = (*SIPMatcher)(nil)

	// SIPHandler is a layer4 handler that can be provisioned and configured
	// from a Caddyfile.
	_ layer4.NextHandler    = (*SIPHandler)(nil)
	_ caddy.Provisioner     = (*SIPHandler)(nil)
	_ caddyfile.Unmarshaler = (*SIPHandler)(nil)
)