package sipguardian import ( "bytes" "regexp" "strconv" "strings" "sync" "time" "go.uber.org/zap" ) // ValidationConfig holds configuration for SIP message validation type ValidationConfig struct { // Enabled determines if validation is active Enabled bool `json:"enabled,omitempty"` // Mode determines validation strictness: permissive, strict, paranoid Mode ValidationMode `json:"mode,omitempty"` // MaxMessageSize limits SIP message size (default: 65535) MaxMessageSize int `json:"max_message_size,omitempty"` // BanOnNullBytes causes immediate ban for null byte injection BanOnNullBytes bool `json:"ban_on_null_bytes,omitempty"` // BanOnBinaryInjection causes immediate ban for binary data injection BanOnBinaryInjection bool `json:"ban_on_binary_injection,omitempty"` // DisabledRules lists rules to skip DisabledRules []string `json:"disabled_rules,omitempty"` } // ValidationMode determines validation strictness type ValidationMode string const ( // ValidationModePermissive logs violations but only blocks critical attacks ValidationModePermissive ValidationMode = "permissive" // ValidationModeStrict enforces RFC 3261 compliance ValidationModeStrict ValidationMode = "strict" // ValidationModeParanoid applies extra security heuristics ValidationModeParanoid ValidationMode = "paranoid" ) // ViolationSeverity indicates how serious a violation is type ViolationSeverity string const ( SeverityCritical ViolationSeverity = "critical" // Immediate ban SeverityHigh ViolationSeverity = "high" // Count toward ban SeverityMedium ViolationSeverity = "medium" // Reject only SeverityLow ViolationSeverity = "low" // Log only ) // Violation represents a single validation rule violation type Violation struct { Rule string `json:"rule"` Severity ViolationSeverity `json:"severity"` Message string `json:"message"` Details string `json:"details,omitempty"` } // ValidationResult contains the result of validating a SIP message type ValidationResult struct { Valid bool `json:"valid"` Violations []Violation `json:"violations,omitempty"` ShouldBan bool `json:"should_ban"` BanReason string `json:"ban_reason,omitempty"` } // SIPValidator validates SIP messages against RFC 3261 and security rules type SIPValidator struct { config ValidationConfig disabledSet map[string]bool logger *zap.Logger mu sync.RWMutex stats validatorStats } type validatorStats struct { totalValidated int64 totalValid int64 totalInvalid int64 violationsByRule map[string]int64 mu sync.Mutex } // DefaultValidationConfig returns sensible defaults func DefaultValidationConfig() ValidationConfig { return ValidationConfig{ Enabled: true, Mode: ValidationModePermissive, MaxMessageSize: 65535, BanOnNullBytes: true, BanOnBinaryInjection: true, DisabledRules: []string{}, } } // Global validator instance var ( globalValidator *SIPValidator validatorMu sync.RWMutex ) // GetValidator returns the global validator instance func GetValidator(logger *zap.Logger) *SIPValidator { validatorMu.Lock() defer validatorMu.Unlock() if globalValidator == nil { globalValidator = NewSIPValidator(logger, DefaultValidationConfig()) } return globalValidator } // SetValidationConfig updates the global validator configuration func SetValidationConfig(config ValidationConfig) { validatorMu.Lock() defer validatorMu.Unlock() if globalValidator == nil { globalValidator = &SIPValidator{ config: config, disabledSet: make(map[string]bool), logger: zap.NewNop(), stats: validatorStats{ violationsByRule: make(map[string]int64), }, } } else { globalValidator.config = config } // Rebuild disabled rules set globalValidator.disabledSet = make(map[string]bool) for _, rule := range config.DisabledRules { globalValidator.disabledSet[rule] = true } } // NewSIPValidator creates a new validator func NewSIPValidator(logger *zap.Logger, config ValidationConfig) *SIPValidator { disabledSet := make(map[string]bool) for _, rule := range config.DisabledRules { disabledSet[rule] = true } return &SIPValidator{ config: config, disabledSet: disabledSet, logger: logger, stats: validatorStats{ violationsByRule: make(map[string]int64), }, } } // Validate checks a SIP message for RFC compliance and security issues func (v *SIPValidator) Validate(data []byte) ValidationResult { v.mu.RLock() config := v.config disabled := v.disabledSet v.mu.RUnlock() if !config.Enabled { return ValidationResult{Valid: true} } result := ValidationResult{Valid: true} var violations []Violation // Critical checks first (can cause immediate ban) if !disabled["null_bytes"] { if v := v.checkNullBytes(data); v != nil { violations = append(violations, *v) if config.BanOnNullBytes { result.ShouldBan = true result.BanReason = "validation_null_bytes" } } } if !disabled["binary_injection"] { if v := v.checkBinaryInjection(data); v != nil { violations = append(violations, *v) if config.BanOnBinaryInjection { result.ShouldBan = true result.BanReason = "validation_binary_injection" } } } // Size check if !disabled["oversized_message"] { if v := v.checkMessageSize(data, config.MaxMessageSize); v != nil { violations = append(violations, *v) } } // Parse headers for further validation headers := v.parseHeaders(data) // Required header checks if !disabled["missing_via"] { if v := v.checkRequiredHeader(headers, "Via", SeverityHigh); v != nil { violations = append(violations, *v) } } if !disabled["missing_from"] { if v := v.checkRequiredHeader(headers, "From", SeverityHigh); v != nil { violations = append(violations, *v) } } if !disabled["missing_to"] { if v := v.checkRequiredHeader(headers, "To", SeverityHigh); v != nil { violations = append(violations, *v) } } if !disabled["missing_call_id"] { if v := v.checkRequiredHeader(headers, "Call-ID", SeverityHigh); v != nil { violations = append(violations, *v) } } if !disabled["missing_cseq"] { if v := v.checkRequiredHeader(headers, "CSeq", SeverityHigh); v != nil { violations = append(violations, *v) } } if !disabled["missing_max_forwards"] { if v := v.checkRequiredHeader(headers, "Max-Forwards", SeverityMedium); v != nil { violations = append(violations, *v) } } // CSeq validation if !disabled["cseq_out_of_range"] { if v := v.checkCSeq(headers); v != nil { violations = append(violations, *v) } } // Content-Length validation if !disabled["content_length_mismatch"] { if v := v.checkContentLength(headers, data); v != nil { violations = append(violations, *v) } } // Via branch validation (strict/paranoid mode) if config.Mode == ValidationModeStrict || config.Mode == ValidationModeParanoid { if !disabled["via_invalid_branch"] { if v := v.checkViaBranch(headers); v != nil { violations = append(violations, *v) } } } // Request-URI validation (strict/paranoid mode) if config.Mode == ValidationModeStrict || config.Mode == ValidationModeParanoid { if !disabled["invalid_request_uri"] { if v := v.checkRequestURI(data); v != nil { violations = append(violations, *v) } } } // Paranoid mode extra checks if config.Mode == ValidationModeParanoid { if !disabled["suspicious_headers"] { violations = append(violations, v.checkSuspiciousHeaders(headers)...) } } // Process violations if len(violations) > 0 { result.Valid = false result.Violations = violations // Record stats v.stats.mu.Lock() v.stats.totalValidated++ v.stats.totalInvalid++ for _, viol := range violations { v.stats.violationsByRule[viol.Rule]++ } v.stats.mu.Unlock() // Determine if violations warrant a ban (in strict/paranoid mode) if !result.ShouldBan && (config.Mode == ValidationModeStrict || config.Mode == ValidationModeParanoid) { for _, viol := range violations { if viol.Severity == SeverityCritical || viol.Severity == SeverityHigh { result.ShouldBan = true result.BanReason = "validation_" + viol.Rule break } } } } else { v.stats.mu.Lock() v.stats.totalValidated++ v.stats.totalValid++ v.stats.mu.Unlock() } return result } // checkNullBytes detects null byte injection attacks func (v *SIPValidator) checkNullBytes(data []byte) *Violation { if bytes.Contains(data, []byte{0x00}) { return &Violation{ Rule: "null_bytes", Severity: SeverityCritical, Message: "Null byte injection detected", Details: "SIP message contains null bytes which may indicate an injection attack", } } return nil } // checkBinaryInjection detects non-ASCII binary data in headers func (v *SIPValidator) checkBinaryInjection(data []byte) *Violation { // Split headers from body parts := bytes.SplitN(data, []byte("\r\n\r\n"), 2) headerSection := parts[0] // Check for non-printable ASCII in headers (except CR, LF, TAB) for i, b := range headerSection { if b < 0x09 || (b > 0x0D && b < 0x20) || b > 0x7E { // Allow UTF-8 in display names, but flag truly binary data if b < 0x80 { return &Violation{ Rule: "binary_injection", Severity: SeverityCritical, Message: "Binary data injection detected", Details: "Non-printable character 0x" + strconv.FormatInt(int64(b), 16) + " at position " + strconv.Itoa(i), } } } } return nil } // checkMessageSize validates message doesn't exceed limits func (v *SIPValidator) checkMessageSize(data []byte, maxSize int) *Violation { if len(data) > maxSize { return &Violation{ Rule: "oversized_message", Severity: SeverityHigh, Message: "Message exceeds maximum size", Details: "Size " + strconv.Itoa(len(data)) + " exceeds limit " + strconv.Itoa(maxSize), } } return nil } // parseHeaders extracts headers from SIP message func (v *SIPValidator) parseHeaders(data []byte) map[string][]string { headers := make(map[string][]string) // Split headers from body parts := bytes.SplitN(data, []byte("\r\n\r\n"), 2) headerSection := string(parts[0]) lines := strings.Split(headerSection, "\r\n") // Skip request/status line for i := 1; i < len(lines); i++ { line := lines[i] if line == "" { continue } // Handle header continuation (line starting with whitespace) if len(line) > 0 && (line[0] == ' ' || line[0] == '\t') { // Append to previous header continue } // Parse header name: value colonIdx := strings.Index(line, ":") if colonIdx > 0 { name := strings.TrimSpace(line[:colonIdx]) value := strings.TrimSpace(line[colonIdx+1:]) // Normalize header name (compact forms) name = normalizeHeaderName(name) headers[name] = append(headers[name], value) } } return headers } // normalizeHeaderName handles SIP compact header forms func normalizeHeaderName(name string) string { // Map compact forms to full names compactForms := map[string]string{ "i": "Call-ID", "m": "Contact", "e": "Content-Encoding", "l": "Content-Length", "c": "Content-Type", "f": "From", "s": "Subject", "k": "Supported", "t": "To", "v": "Via", } lower := strings.ToLower(name) if full, ok := compactForms[lower]; ok { return full } // Normalize case for common headers switch lower { case "via": return "Via" case "from": return "From" case "to": return "To" case "call-id": return "Call-ID" case "cseq": return "CSeq" case "max-forwards": return "Max-Forwards" case "content-length": return "Content-Length" case "contact": return "Contact" default: return name } } // checkRequiredHeader validates a required header is present func (v *SIPValidator) checkRequiredHeader(headers map[string][]string, name string, severity ViolationSeverity) *Violation { if values, ok := headers[name]; !ok || len(values) == 0 || values[0] == "" { return &Violation{ Rule: "missing_" + strings.ToLower(strings.ReplaceAll(name, "-", "_")), Severity: severity, Message: "Required header missing: " + name, } } return nil } // checkCSeq validates CSeq header format and range func (v *SIPValidator) checkCSeq(headers map[string][]string) *Violation { values, ok := headers["CSeq"] if !ok || len(values) == 0 { return nil // Missing CSeq handled by required header check } cseq := values[0] parts := strings.Fields(cseq) if len(parts) < 2 { return &Violation{ Rule: "cseq_invalid_format", Severity: SeverityHigh, Message: "Invalid CSeq format", Details: "Expected 'sequence method', got: " + cseq, } } seq, err := strconv.ParseInt(parts[0], 10, 64) if err != nil { return &Violation{ Rule: "cseq_invalid_sequence", Severity: SeverityHigh, Message: "CSeq sequence number invalid", Details: "Could not parse: " + parts[0], } } // RFC 3261: CSeq must be 0 to 2^31-1 if seq < 0 || seq > 2147483647 { return &Violation{ Rule: "cseq_out_of_range", Severity: SeverityMedium, Message: "CSeq sequence number out of range", Details: "Value " + parts[0] + " outside valid range 0-2147483647", } } return nil } // checkContentLength validates Content-Length matches actual body func (v *SIPValidator) checkContentLength(headers map[string][]string, data []byte) *Violation { values, ok := headers["Content-Length"] if !ok || len(values) == 0 { return nil // Optional header } claimLength, err := strconv.Atoi(values[0]) if err != nil { return &Violation{ Rule: "content_length_invalid", Severity: SeverityHigh, Message: "Invalid Content-Length value", Details: "Could not parse: " + values[0], } } // Find actual body parts := bytes.SplitN(data, []byte("\r\n\r\n"), 2) actualLength := 0 if len(parts) > 1 { actualLength = len(parts[1]) } // Allow some tolerance for line ending differences if claimLength != actualLength { // Check if difference is just CRLF vs LF if abs(claimLength-actualLength) > 2 { return &Violation{ Rule: "content_length_mismatch", Severity: SeverityHigh, Message: "Content-Length doesn't match body size", Details: "Claimed " + values[0] + ", actual " + strconv.Itoa(actualLength), } } } return nil } // checkViaBranch validates Via branch parameter format func (v *SIPValidator) checkViaBranch(headers map[string][]string) *Violation { values, ok := headers["Via"] if !ok || len(values) == 0 { return nil } via := values[0] // RFC 3261: branch must start with "z9hG4bK" for 3261-compliant UAs branchPattern := regexp.MustCompile(`branch=([^;,\s]+)`) matches := branchPattern.FindStringSubmatch(via) if len(matches) > 1 { branch := matches[1] if !strings.HasPrefix(branch, "z9hG4bK") { return &Violation{ Rule: "via_invalid_branch", Severity: SeverityLow, Message: "Via branch doesn't follow RFC 3261 format", Details: "Branch should start with 'z9hG4bK', got: " + branch, } } } return nil } // checkRequestURI validates Request-URI format func (v *SIPValidator) checkRequestURI(data []byte) *Violation { // Get first line idx := bytes.Index(data, []byte("\r\n")) if idx < 0 { return &Violation{ Rule: "invalid_request_line", Severity: SeverityHigh, Message: "Could not parse request line", } } requestLine := string(data[:idx]) parts := strings.Fields(requestLine) if len(parts) < 3 { return &Violation{ Rule: "invalid_request_line", Severity: SeverityHigh, Message: "Malformed request line", Details: requestLine, } } // Check if it's a response (starts with SIP/2.0) if strings.HasPrefix(parts[0], "SIP/") { return nil // Skip URI validation for responses } uri := parts[1] // Basic SIP URI validation if !strings.HasPrefix(strings.ToLower(uri), "sip:") && !strings.HasPrefix(strings.ToLower(uri), "sips:") && !strings.HasPrefix(strings.ToLower(uri), "tel:") { return &Violation{ Rule: "invalid_request_uri", Severity: SeverityHigh, Message: "Invalid Request-URI scheme", Details: "Expected sip:, sips:, or tel: URI, got: " + uri, } } return nil } // checkSuspiciousHeaders detects potentially malicious header patterns func (v *SIPValidator) checkSuspiciousHeaders(headers map[string][]string) []Violation { var violations []Violation // Check for extremely long header values for name, values := range headers { for _, val := range values { if len(val) > 1024 { violations = append(violations, Violation{ Rule: "suspicious_long_header", Severity: SeverityMedium, Message: "Unusually long header value", Details: name + " has " + strconv.Itoa(len(val)) + " bytes", }) } } } // Check for excessive number of Via headers (potential loop) if vias, ok := headers["Via"]; ok && len(vias) > 70 { violations = append(violations, Violation{ Rule: "suspicious_via_count", Severity: SeverityMedium, Message: "Excessive Via headers", Details: strconv.Itoa(len(vias)) + " Via headers (potential loop)", }) } // Check for SQL injection patterns in headers sqlPatterns := []string{ "' OR '", "' OR 1=1", "'; DROP", "UNION SELECT", "--", } for name, values := range headers { for _, val := range values { upperVal := strings.ToUpper(val) for _, pattern := range sqlPatterns { if strings.Contains(upperVal, pattern) { violations = append(violations, Violation{ Rule: "suspicious_sql_injection", Severity: SeverityHigh, Message: "Potential SQL injection in header", Details: name + " contains suspicious pattern", }) break } } } } return violations } // GetStats returns validator statistics func (v *SIPValidator) GetStats() map[string]interface{} { v.stats.mu.Lock() defer v.stats.mu.Unlock() violations := make(map[string]int64) for rule, count := range v.stats.violationsByRule { violations[rule] = count } return map[string]interface{}{ "total_validated": v.stats.totalValidated, "total_valid": v.stats.totalValid, "total_invalid": v.stats.totalInvalid, "violations_by_rule": violations, "config_mode": v.config.Mode, "config_enabled": v.config.Enabled, } } // IsEnabled returns whether validation is enabled func (v *SIPValidator) IsEnabled() bool { v.mu.RLock() defer v.mu.RUnlock() return v.config.Enabled } // abs returns absolute value of an int func abs(n int) int { if n < 0 { return -n } return n } // Cleanup resets old statistics (called periodically) func (v *SIPValidator) Cleanup(maxAge time.Duration) { // For now, validation stats don't need time-based cleanup // Could add per-IP validation tracking in the future }