caddy-sip-guardian/validation.go
Ryan Malloy 976fdf53a5 Add SIP message validation feature
Implements RFC 3261 compliance checking and security validation:

- Three validation modes: permissive (default), strict, paranoid
- Critical checks: null bytes, binary injection (immediate ban)
- RFC compliance: required headers (Via, From, To, Call-ID, CSeq, Max-Forwards)
- Format validation: CSeq range, Content-Length, Via branch format
- Paranoid mode: SQL injection patterns, excessive headers, long values
- Compact header form support (v, f, t, i, l, etc.)

Caddyfile configuration:
  validation {
      enabled true
      mode permissive
      max_message_size 65535
      ban_on_null_bytes true
      ban_on_binary_injection true
      disabled_rules via_invalid_branch
  }

New Prometheus metrics:
- sip_guardian_validation_violations_total{rule}
- sip_guardian_validation_results_total{result}
- sip_guardian_message_size_bytes (histogram)

Includes comprehensive unit tests covering all validation scenarios.
2025-12-07 15:57:26 -07:00

710 lines
18 KiB
Go

package sipguardian
import (
"bytes"
"regexp"
"strconv"
"strings"
"sync"
"time"
"go.uber.org/zap"
)
// ValidationConfig holds configuration for SIP message validation.
type ValidationConfig struct {
	// Enabled determines if validation is active. When false, Validate
	// reports every message as valid without inspecting it.
	Enabled bool `json:"enabled,omitempty"`
	// Mode determines validation strictness: permissive, strict, paranoid.
	// Any value other than strict or paranoid (including the empty string)
	// receives only the permissive-level checks.
	Mode ValidationMode `json:"mode,omitempty"`
	// MaxMessageSize limits SIP message size in bytes
	// (DefaultValidationConfig sets 65535).
	MaxMessageSize int `json:"max_message_size,omitempty"`
	// BanOnNullBytes causes immediate ban for null byte injection, in every mode.
	BanOnNullBytes bool `json:"ban_on_null_bytes,omitempty"`
	// BanOnBinaryInjection causes immediate ban when control characters are
	// found in the header section, in every mode.
	BanOnBinaryInjection bool `json:"ban_on_binary_injection,omitempty"`
	// DisabledRules lists rule names whose checks are skipped, e.g.
	// "null_bytes" or "via_invalid_branch".
	DisabledRules []string `json:"disabled_rules,omitempty"`
}
// ValidationMode selects how strict SIP message validation is. The string
// values below are what users put in configuration.
type ValidationMode string

// Supported validation modes, from least to most strict.
const (
	// ValidationModePermissive reports violations but only requests a ban for
	// critical findings (null bytes / binary injection).
	ValidationModePermissive ValidationMode = "permissive"

	// ValidationModeStrict adds RFC 3261 format checks and bans on any
	// critical or high severity violation.
	ValidationModeStrict ValidationMode = "strict"

	// ValidationModeParanoid is strict mode plus extra security heuristics.
	ValidationModeParanoid ValidationMode = "paranoid"
)
// ViolationSeverity ranks how serious a single rule violation is. The
// trailing comments describe the intended handling of each level; the actual
// ban decision is made by Validate.
type ViolationSeverity string

// Severity levels, most serious first.
const (
	SeverityCritical ViolationSeverity = "critical" // Immediate ban
	SeverityHigh     ViolationSeverity = "high"     // Count toward ban
	SeverityMedium   ViolationSeverity = "medium"   // Reject only
	SeverityLow      ViolationSeverity = "low"      // Log only
)
// Violation represents a single validation rule violation.
type Violation struct {
	// Rule is the machine-readable rule name, e.g. "missing_via".
	Rule string `json:"rule"`
	// Severity indicates how serious the violation is.
	Severity ViolationSeverity `json:"severity"`
	// Message is a short human-readable summary.
	Message string `json:"message"`
	// Details optionally carries context such as the offending value or byte position.
	Details string `json:"details,omitempty"`
}

// ValidationResult contains the result of validating a SIP message.
type ValidationResult struct {
	// Valid is false when at least one violation was found.
	Valid bool `json:"valid"`
	// Violations lists every violation found, in check order.
	Violations []Violation `json:"violations,omitempty"`
	// ShouldBan recommends banning the sender; Validate only sets the flag
	// and does not ban anything itself.
	ShouldBan bool `json:"should_ban"`
	// BanReason has the form "validation_<rule>" naming what triggered ShouldBan.
	BanReason string `json:"ban_reason,omitempty"`
}
// SIPValidator validates SIP messages against RFC 3261 and security rules.
type SIPValidator struct {
	// config is the active validation configuration.
	config ValidationConfig
	// disabledSet is config.DisabledRules in set form for O(1) lookups.
	disabledSet map[string]bool
	logger      *zap.Logger
	// mu is read-locked by Validate and IsEnabled while reading config and disabledSet.
	mu    sync.RWMutex
	stats validatorStats
}

// validatorStats accumulates counters updated by Validate and read by GetStats.
type validatorStats struct {
	// totalValidated counts messages inspected while validation was enabled.
	totalValidated int64
	totalValid     int64
	totalInvalid   int64
	// violationsByRule counts violations keyed by Violation.Rule.
	violationsByRule map[string]int64
	// mu guards every counter above.
	mu sync.Mutex
}
// DefaultValidationConfig returns the baseline configuration: validation on,
// permissive mode, a 65535-byte message limit, immediate bans for null-byte
// and binary injection, and no disabled rules.
func DefaultValidationConfig() ValidationConfig {
	var cfg ValidationConfig
	cfg.Enabled = true
	cfg.Mode = ValidationModePermissive
	cfg.MaxMessageSize = 65535
	cfg.BanOnNullBytes = true
	cfg.BanOnBinaryInjection = true
	cfg.DisabledRules = []string{} // non-nil empty list, as before
	return cfg
}
// globalValidator is the process-wide validator returned by GetValidator and
// reconfigured by SetValidationConfig.
var globalValidator *SIPValidator

// validatorMu serializes creation and reconfiguration of globalValidator.
var validatorMu sync.RWMutex
// GetValidator returns the process-wide validator, lazily creating it with
// DefaultValidationConfig on first use.
//
// logger is only used when this call creates the validator. If an instance
// already exists, including one created by SetValidationConfig with a no-op
// logger, the argument is ignored.
func GetValidator(logger *zap.Logger) *SIPValidator {
	validatorMu.Lock()
	defer validatorMu.Unlock()

	if globalValidator != nil {
		return globalValidator
	}
	globalValidator = NewSIPValidator(logger, DefaultValidationConfig())
	return globalValidator
}
// SetValidationConfig replaces the global validator's configuration. If no
// validator exists yet, it creates one with a no-op logger.
//
// The new disabled-rule set is fully built before it is published. The
// config and the set are then swapped in while holding the validator's own
// mu. Validate snapshots both under mu.RLock, so an in-flight validation sees
// either the old or the new configuration, never a partially filled map.
func SetValidationConfig(config ValidationConfig) {
	disabledSet := make(map[string]bool, len(config.DisabledRules))
	for _, rule := range config.DisabledRules {
		disabledSet[rule] = true
	}

	validatorMu.Lock()
	defer validatorMu.Unlock()

	if globalValidator == nil {
		globalValidator = &SIPValidator{
			config:      config,
			disabledSet: disabledSet,
			logger:      zap.NewNop(),
			stats: validatorStats{
				violationsByRule: make(map[string]int64),
			},
		}
		return
	}

	// The instance may already be shared with concurrent Validate calls.
	globalValidator.mu.Lock()
	globalValidator.config = config
	globalValidator.disabledSet = disabledSet
	globalValidator.mu.Unlock()
}
// NewSIPValidator builds a validator for config, using logger for its
// diagnostics. config.DisabledRules is copied into a lookup set, and the
// per-rule statistics map is initialized so Validate can update it directly.
func NewSIPValidator(logger *zap.Logger, config ValidationConfig) *SIPValidator {
	validator := &SIPValidator{
		config:      config,
		logger:      logger,
		disabledSet: make(map[string]bool),
	}
	validator.stats.violationsByRule = make(map[string]int64)

	for _, ruleName := range config.DisabledRules {
		validator.disabledSet[ruleName] = true
	}
	return validator
}
// Validate checks a SIP message for RFC 3261 compliance and security issues.
//
// The configuration and disabled-rule set are snapshotted under v.mu, so a
// concurrent reconfiguration does not affect a validation already in
// progress.
//
// A check whose gate name appears in DisabledRules is skipped entirely. In
// addition, any individual violation whose Rule is disabled is dropped, so
// sub-rules such as "cseq_invalid_format", "content_length_invalid" or
// "suspicious_sql_injection" can also be turned off.
//
// Result semantics:
//   - Valid is false whenever at least one violation is recorded.
//   - ShouldBan is set in any mode for null-byte or binary-injection findings
//     when the corresponding Ban* option is on. In strict and paranoid mode
//     it is also set for the first critical- or high-severity violation.
//   - BanReason is "validation_" plus the rule that first triggered the ban.
//
// When validation is disabled, the message is reported valid and no
// statistics are recorded.
func (v *SIPValidator) Validate(data []byte) ValidationResult {
	v.mu.RLock()
	config := v.config
	disabled := v.disabledSet
	v.mu.RUnlock()

	if !config.Enabled {
		return ValidationResult{Valid: true}
	}

	strictOrParanoid := config.Mode == ValidationModeStrict || config.Mode == ValidationModeParanoid
	result := ValidationResult{Valid: true}
	var violations []Violation

	// record appends viol unless it is nil or its specific rule is disabled.
	// It reports whether anything was recorded.
	record := func(viol *Violation) bool {
		if viol == nil || disabled[viol.Rule] {
			return false
		}
		violations = append(violations, *viol)
		return true
	}
	// banFor marks the result for banning and keeps the first, most specific
	// reason. A null byte in the headers also trips the binary-injection
	// check, and must not have its reason overwritten by it.
	banFor := func(reason string) {
		result.ShouldBan = true
		if result.BanReason == "" {
			result.BanReason = reason
		}
	}

	// Critical checks first; these can request a ban in any mode.
	if !disabled["null_bytes"] && record(v.checkNullBytes(data)) && config.BanOnNullBytes {
		banFor("validation_null_bytes")
	}
	if !disabled["binary_injection"] && record(v.checkBinaryInjection(data)) && config.BanOnBinaryInjection {
		banFor("validation_binary_injection")
	}
	if !disabled["oversized_message"] {
		record(v.checkMessageSize(data, config.MaxMessageSize))
	}

	headers := v.parseHeaders(data)

	// Mandatory request headers: gate name, header name, severity.
	required := []struct {
		gate     string
		name     string
		severity ViolationSeverity
	}{
		{"missing_via", "Via", SeverityHigh},
		{"missing_from", "From", SeverityHigh},
		{"missing_to", "To", SeverityHigh},
		{"missing_call_id", "Call-ID", SeverityHigh},
		{"missing_cseq", "CSeq", SeverityHigh},
		{"missing_max_forwards", "Max-Forwards", SeverityMedium},
	}
	for _, req := range required {
		if !disabled[req.gate] {
			record(v.checkRequiredHeader(headers, req.name, req.severity))
		}
	}

	if !disabled["cseq_out_of_range"] {
		record(v.checkCSeq(headers))
	}
	if !disabled["content_length_mismatch"] {
		record(v.checkContentLength(headers, data))
	}

	// Format checks that only apply in strict/paranoid mode.
	if strictOrParanoid {
		if !disabled["via_invalid_branch"] {
			record(v.checkViaBranch(headers))
		}
		if !disabled["invalid_request_uri"] {
			record(v.checkRequestURI(data))
		}
	}

	// Heuristics that only apply in paranoid mode.
	if config.Mode == ValidationModeParanoid && !disabled["suspicious_headers"] {
		suspicious := v.checkSuspiciousHeaders(headers)
		for i := range suspicious {
			record(&suspicious[i])
		}
	}

	v.stats.mu.Lock()
	v.stats.totalValidated++
	if len(violations) == 0 {
		v.stats.totalValid++
	} else {
		v.stats.totalInvalid++
		for _, viol := range violations {
			v.stats.violationsByRule[viol.Rule]++
		}
	}
	v.stats.mu.Unlock()

	if len(violations) == 0 {
		return result
	}
	result.Valid = false
	result.Violations = violations

	// In strict/paranoid mode, escalate the first serious violation to a ban.
	if !result.ShouldBan && strictOrParanoid {
		for _, viol := range violations {
			if viol.Severity == SeverityCritical || viol.Severity == SeverityHigh {
				banFor("validation_" + viol.Rule)
				break
			}
		}
	}
	return result
}
// checkNullBytes reports a critical "null_bytes" violation when any NUL
// (0x00) byte appears anywhere in the message, headers or body. It returns
// nil for clean input.
func (v *SIPValidator) checkNullBytes(data []byte) *Violation {
	if bytes.IndexByte(data, 0x00) < 0 {
		return nil
	}
	return &Violation{
		Rule:     "null_bytes",
		Severity: SeverityCritical,
		Message:  "Null byte injection detected",
		Details:  "SIP message contains null bytes which may indicate an injection attack",
	}
}
// checkBinaryInjection detects ASCII control characters in the header
// section, i.e. everything before the first CRLFCRLF.
//
// Only TAB, CR and LF are legitimate control bytes in SIP headers (RFC 3261
// LWS/CRLF). Every other byte below 0x20, and DEL (0x7F), is reported as a
// critical "binary_injection" violation, with the offending byte and its
// offset in Details. Bytes >= 0x80 are allowed so that UTF-8 text (e.g. in
// display names) passes.
func (v *SIPValidator) checkBinaryInjection(data []byte) *Violation {
	parts := bytes.SplitN(data, []byte("\r\n\r\n"), 2)
	headerSection := parts[0]

	for i, b := range headerSection {
		isControl := b < 0x20 || b == 0x7F
		if isControl && b != '\t' && b != '\r' && b != '\n' {
			return &Violation{
				Rule:     "binary_injection",
				Severity: SeverityCritical,
				Message:  "Binary data injection detected",
				Details:  "Non-printable character 0x" + strconv.FormatInt(int64(b), 16) + " at position " + strconv.Itoa(i),
			}
		}
	}
	return nil
}
// checkMessageSize reports a high-severity "oversized_message" violation when
// data is longer than maxSize bytes.
//
// A zero or negative maxSize means "not configured", for example a config
// that omits max_message_size. It falls back to the documented default of
// 65535 instead of flagging every non-empty message as oversized.
func (v *SIPValidator) checkMessageSize(data []byte, maxSize int) *Violation {
	const defaultMaxMessageSize = 65535
	if maxSize <= 0 {
		maxSize = defaultMaxMessageSize
	}
	if len(data) <= maxSize {
		return nil
	}
	return &Violation{
		Rule:     "oversized_message",
		Severity: SeverityHigh,
		Message:  "Message exceeds maximum size",
		Details:  "Size " + strconv.Itoa(len(data)) + " exceeds limit " + strconv.Itoa(maxSize),
	}
}
// parseHeaders extracts the header fields of a SIP message into a map.
//
// The map is keyed by normalized header name (compact forms expanded). Each
// key maps to the trimmed values in order of appearance.
//
// Parsing details:
//   - The header section ends at the first CRLFCRLF.
//   - Lines are split on CRLF, and the first line (request/status line) is skipped.
//   - Lines without a "name:" prefix are ignored.
//   - A line starting with SP or HTAB continues the previous header (RFC 3261
//     §7.3.1 folding). Its trimmed text is appended to that header's last
//     value, separated by a single space.
func (v *SIPValidator) parseHeaders(data []byte) map[string][]string {
	headers := make(map[string][]string)

	parts := bytes.SplitN(data, []byte("\r\n\r\n"), 2)
	lines := strings.Split(string(parts[0]), "\r\n")

	lastName := "" // header that a following continuation line belongs to
	for _, line := range lines[1:] {
		if line == "" {
			continue
		}

		if line[0] == ' ' || line[0] == '\t' {
			// Folded continuation: equivalent to a single SP in the value.
			if vals := headers[lastName]; lastName != "" && len(vals) > 0 {
				if folded := strings.TrimSpace(line); folded != "" {
					last := len(vals) - 1
					vals[last] = strings.TrimSpace(vals[last] + " " + folded)
				}
			}
			continue
		}

		colonIdx := strings.Index(line, ":")
		if colonIdx <= 0 {
			lastName = "" // unparseable line; don't fold later text into an earlier header
			continue
		}
		name := normalizeHeaderName(strings.TrimSpace(line[:colonIdx]))
		value := strings.TrimSpace(line[colonIdx+1:])
		headers[name] = append(headers[name], value)
		lastName = name
	}
	return headers
}
// sipCanonicalHeaderNames maps lower-cased SIP header names to their
// canonical spelling. It covers the single-letter compact forms and the long
// forms of the same headers, so that "v", "V" and "via" all yield "Via". It is
// built once at package initialization because normalization runs for every
// header line of every validated message.
var sipCanonicalHeaderNames = map[string]string{
	// Compact forms (RFC 3261 §7.3.3).
	"i": "Call-ID",
	"m": "Contact",
	"e": "Content-Encoding",
	"l": "Content-Length",
	"c": "Content-Type",
	"f": "From",
	"s": "Subject",
	"k": "Supported",
	"t": "To",
	"v": "Via",
	// Compact forms registered by later SIP extensions (IANA SIP parameters).
	"a": "Accept-Contact",
	"b": "Referred-By",
	"d": "Request-Disposition",
	"j": "Reject-Contact",
	"n": "Identity-Info",
	"o": "Event",
	"r": "Refer-To",
	"u": "Allow-Events",
	"x": "Session-Expires",
	"y": "Identity",
	// Long forms.
	"via":                 "Via",
	"from":                "From",
	"to":                  "To",
	"call-id":             "Call-ID",
	"cseq":                "CSeq",
	"max-forwards":        "Max-Forwards",
	"content-length":      "Content-Length",
	"contact":             "Contact",
	"content-encoding":    "Content-Encoding",
	"content-type":        "Content-Type",
	"subject":             "Subject",
	"supported":           "Supported",
	"accept-contact":      "Accept-Contact",
	"referred-by":         "Referred-By",
	"request-disposition": "Request-Disposition",
	"reject-contact":      "Reject-Contact",
	"identity-info":       "Identity-Info",
	"event":               "Event",
	"refer-to":            "Refer-To",
	"allow-events":        "Allow-Events",
	"session-expires":     "Session-Expires",
	"identity":            "Identity",
}

// normalizeHeaderName maps a SIP header name to its canonical spelling.
// Compact forms are expanded and known long forms are case-normalized, both
// case-insensitively. Unknown names are returned unchanged.
func normalizeHeaderName(name string) string {
	if canonical, ok := sipCanonicalHeaderNames[strings.ToLower(name)]; ok {
		return canonical
	}
	return name
}
// checkRequiredHeader reports a violation with the given severity when the
// header called name is absent or its first value is empty. The rule name is
// "missing_" plus the lower-cased header name with '-' replaced by '_', e.g.
// "missing_call_id".
func (v *SIPValidator) checkRequiredHeader(headers map[string][]string, name string, severity ViolationSeverity) *Violation {
	values := headers[name]
	if len(values) > 0 && values[0] != "" {
		return nil
	}

	ruleSuffix := strings.ToLower(strings.ReplaceAll(name, "-", "_"))
	return &Violation{
		Rule:     "missing_" + ruleSuffix,
		Severity: severity,
		Message:  "Required header missing: " + name,
	}
}
// checkCSeq validates the first CSeq value. It must have the form
// "<sequence> <method>", and the sequence must be an integer in
// [0, 2^31-1]. A missing CSeq yields nil because the required-header check
// already reports it.
func (v *SIPValidator) checkCSeq(headers map[string][]string) *Violation {
	cseqValues := headers["CSeq"]
	if len(cseqValues) == 0 {
		return nil // Missing CSeq handled by required header check
	}

	raw := cseqValues[0]
	fields := strings.Fields(raw)
	if len(fields) < 2 {
		return &Violation{
			Rule:     "cseq_invalid_format",
			Severity: SeverityHigh,
			Message:  "Invalid CSeq format",
			Details:  "Expected 'sequence method', got: " + raw,
		}
	}

	const maxCSeq = 2147483647 // RFC 3261: sequence number < 2^31
	seqText := fields[0]
	seqNum, parseErr := strconv.ParseInt(seqText, 10, 64)
	switch {
	case parseErr != nil:
		return &Violation{
			Rule:     "cseq_invalid_sequence",
			Severity: SeverityHigh,
			Message:  "CSeq sequence number invalid",
			Details:  "Could not parse: " + seqText,
		}
	case seqNum < 0 || seqNum > maxCSeq:
		return &Violation{
			Rule:     "cseq_out_of_range",
			Severity: SeverityMedium,
			Message:  "CSeq sequence number out of range",
			Details:  "Value " + seqText + " outside valid range 0-2147483647",
		}
	default:
		return nil
	}
}
// checkContentLength validates the first Content-Length value. It must be a
// non-negative integer that matches the size of the body after the first
// CRLFCRLF.
//
// The header is optional, so its absence is not a violation. A difference of
// up to two bytes between the claimed and actual length is tolerated to
// absorb CRLF-vs-LF discrepancies.
func (v *SIPValidator) checkContentLength(headers map[string][]string, data []byte) *Violation {
	values := headers["Content-Length"]
	if len(values) == 0 {
		return nil // Optional header
	}

	claimLength, err := strconv.Atoi(values[0])
	if err != nil {
		return &Violation{
			Rule:     "content_length_invalid",
			Severity: SeverityHigh,
			Message:  "Invalid Content-Length value",
			Details:  "Could not parse: " + values[0],
		}
	}
	if claimLength < 0 {
		// Content-Length is 1*DIGIT (RFC 3261 §20.14). A negative value would
		// otherwise slip through the ±2 tolerance below when the body is
		// empty, e.g. "-1".
		return &Violation{
			Rule:     "content_length_invalid",
			Severity: SeverityHigh,
			Message:  "Invalid Content-Length value",
			Details:  "Negative value: " + values[0],
		}
	}

	parts := bytes.SplitN(data, []byte("\r\n\r\n"), 2)
	actualLength := 0
	if len(parts) > 1 {
		actualLength = len(parts[1])
	}

	if abs(claimLength-actualLength) > 2 {
		return &Violation{
			Rule:     "content_length_mismatch",
			Severity: SeverityHigh,
			Message:  "Content-Length doesn't match body size",
			Details:  "Claimed " + values[0] + ", actual " + strconv.Itoa(actualLength),
		}
	}
	return nil
}
// viaBranchParam extracts the value of the "branch" parameter from a Via
// header value. Parameter names are case-insensitive and may have whitespace
// around "=" (RFC 3261 §7.3.1, §20.42). Anchoring on ";" avoids matching
// extension parameters whose names merely end in "branch". It is compiled
// once at package initialization instead of on every validated message.
var viaBranchParam = regexp.MustCompile(`(?i);\s*branch\s*=\s*([^;,\s]+)`)

// checkViaBranch checks that the branch parameter of the first Via value
// starts with the RFC 3261 magic cookie "z9hG4bK". A missing Via or missing
// branch parameter yields nil. A non-compliant branch is reported as a
// low-severity "via_invalid_branch" violation.
func (v *SIPValidator) checkViaBranch(headers map[string][]string) *Violation {
	values := headers["Via"]
	if len(values) == 0 {
		return nil
	}

	matches := viaBranchParam.FindStringSubmatch(values[0])
	if len(matches) < 2 {
		return nil
	}
	branch := matches[1]
	if strings.HasPrefix(branch, "z9hG4bK") {
		return nil
	}
	return &Violation{
		Rule:     "via_invalid_branch",
		Severity: SeverityLow,
		Message:  "Via branch doesn't follow RFC 3261 format",
		Details:  "Branch should start with 'z9hG4bK', got: " + branch,
	}
}
// checkRequestURI validates the start line of a SIP message (text up to the
// first CRLF).
//
//   - Status lines (first field starts with "SIP/") are accepted as long as
//     they have a version and a status code. RFC 3261 §7.2 allows an empty
//     Reason-Phrase, so only two fields are required.
//   - Request lines need "Method Request-URI SIP-Version", and the URI must
//     use the sip:, sips: or tel: scheme (case-insensitive).
//
// Violations are high severity: "invalid_request_line" for unparseable lines
// and "invalid_request_uri" for a bad scheme.
func (v *SIPValidator) checkRequestURI(data []byte) *Violation {
	idx := bytes.Index(data, []byte("\r\n"))
	if idx < 0 {
		return &Violation{
			Rule:     "invalid_request_line",
			Severity: SeverityHigh,
			Message:  "Could not parse request line",
		}
	}
	requestLine := string(data[:idx])
	parts := strings.Fields(requestLine)

	// Responses carry no Request-URI. Check for them before the three-field
	// requirement so an empty reason phrase is not flagged as malformed.
	if len(parts) >= 2 && strings.HasPrefix(parts[0], "SIP/") {
		return nil
	}

	if len(parts) < 3 {
		return &Violation{
			Rule:     "invalid_request_line",
			Severity: SeverityHigh,
			Message:  "Malformed request line",
			Details:  requestLine,
		}
	}

	uri := parts[1]
	lowerURI := strings.ToLower(uri)
	if strings.HasPrefix(lowerURI, "sip:") ||
		strings.HasPrefix(lowerURI, "sips:") ||
		strings.HasPrefix(lowerURI, "tel:") {
		return nil
	}
	return &Violation{
		Rule:     "invalid_request_uri",
		Severity: SeverityHigh,
		Message:  "Invalid Request-URI scheme",
		Details:  "Expected sip:, sips:, or tel: URI, got: " + uri,
	}
}
// suspiciousSQLPatterns are fragments typical of SQL injection payloads.
// Header values are upper-cased before matching, so every pattern must be
// upper-case. A bare "--" is deliberately absent because dash runs occur in
// legitimate SIP tokens, such as branch and tag values some user agents
// generate. SQL comment markers are only matched right after a closing quote.
var suspiciousSQLPatterns = []string{
	"' OR '",
	"' OR 1=1",
	"'; DROP",
	"UNION SELECT",
	"'--",
	"' --",
}

// checkSuspiciousHeaders applies paranoid-mode heuristics to parsed headers:
//   - any value longer than 1024 bytes (medium severity),
//   - more than 70 Via header lines, a potential loop (medium severity),
//   - SQL-injection-like fragments (high severity), at most one violation
//     per header value.
//
// Violations within each category follow Go map iteration order, so their
// relative order is not deterministic.
func (v *SIPValidator) checkSuspiciousHeaders(headers map[string][]string) []Violation {
	var violations []Violation

	// Extremely long header values.
	for name, values := range headers {
		for _, val := range values {
			if len(val) > 1024 {
				violations = append(violations, Violation{
					Rule:     "suspicious_long_header",
					Severity: SeverityMedium,
					Message:  "Unusually long header value",
					Details:  name + " has " + strconv.Itoa(len(val)) + " bytes",
				})
			}
		}
	}

	// Excessive number of Via header lines.
	if vias, ok := headers["Via"]; ok && len(vias) > 70 {
		violations = append(violations, Violation{
			Rule:     "suspicious_via_count",
			Severity: SeverityMedium,
			Message:  "Excessive Via headers",
			Details:  strconv.Itoa(len(vias)) + " Via headers (potential loop)",
		})
	}

	// SQL injection fragments.
	for name, values := range headers {
		for _, val := range values {
			upperVal := strings.ToUpper(val)
			for _, pattern := range suspiciousSQLPatterns {
				if strings.Contains(upperVal, pattern) {
					violations = append(violations, Violation{
						Rule:     "suspicious_sql_injection",
						Severity: SeverityHigh,
						Message:  "Potential SQL injection in header",
						Details:  name + " contains suspicious pattern",
					})
					break
				}
			}
		}
	}
	return violations
}
// GetStats returns a point-in-time snapshot of validator statistics together
// with the current mode and enabled flag. The returned map and its nested
// "violations_by_rule" map are fresh copies, safe for the caller to modify.
//
// Configuration is read under v.mu, as Validate and IsEnabled do. Counters
// are read under v.stats.mu. The two locks are taken one after the other,
// never nested.
func (v *SIPValidator) GetStats() map[string]interface{} {
	v.mu.RLock()
	mode := v.config.Mode
	enabled := v.config.Enabled
	v.mu.RUnlock()

	v.stats.mu.Lock()
	defer v.stats.mu.Unlock()

	violations := make(map[string]int64, len(v.stats.violationsByRule))
	for rule, count := range v.stats.violationsByRule {
		violations[rule] = count
	}
	return map[string]interface{}{
		"total_validated":    v.stats.totalValidated,
		"total_valid":        v.stats.totalValid,
		"total_invalid":      v.stats.totalInvalid,
		"violations_by_rule": violations,
		"config_mode":        mode,
		"config_enabled":     enabled,
	}
}
// IsEnabled reports whether validation is currently switched on, reading the
// configuration under the validator's read lock.
func (v *SIPValidator) IsEnabled() bool {
	v.mu.RLock()
	enabled := v.config.Enabled
	v.mu.RUnlock()
	return enabled
}
// abs returns the magnitude of n. As with plain negation, the most negative
// int overflows and is returned unchanged.
func abs(n int) int {
	if n >= 0 {
		return n
	}
	return -n
}
// Cleanup is a maintenance hook intended to be called periodically.
// Validation statistics are not time-bucketed, so there is currently nothing
// to expire and maxAge is ignored.
func (v *SIPValidator) Cleanup(maxAge time.Duration) {
	_ = maxAge // reserved for future per-IP validation tracking
}