Add SIP message validation feature
Implements RFC 3261 compliance checking and security validation:
- Three validation modes: permissive (default), strict, paranoid
- Critical checks: null bytes, binary injection (immediate ban)
- RFC compliance: required headers (Via, From, To, Call-ID, CSeq, Max-Forwards)
- Format validation: CSeq range, Content-Length, Via branch format
- Paranoid mode: SQL injection patterns, excessive headers, long values
- Compact header form support (v, f, t, i, l, etc.)
Caddyfile configuration:
validation {
enabled true
mode permissive
max_message_size 65535
ban_on_null_bytes true
ban_on_binary_injection true
disabled_rules via_invalid_branch
}
New Prometheus metrics:
- sip_guardian_validation_violations_total{rule}
- sip_guardian_validation_results_total{result}
- sip_guardian_message_size_bytes (histogram)
Includes comprehensive unit tests covering all validation scenarios.
This commit is contained in:
parent
95a794ba69
commit
976fdf53a5
109
l4handler.go
109
l4handler.go
@ -153,7 +153,7 @@ func (h *SIPHandler) Handle(cx *layer4.Connection, next layer4.Handler) error {
|
||||
|
||||
// Read data from the connection for suspicious pattern detection
|
||||
// caddy-l4 replays prefetched data on read, so we can read the full message here
|
||||
buf := make([]byte, 1024)
|
||||
buf := make([]byte, 4096) // Larger buffer for validation
|
||||
n, err := cx.Read(buf)
|
||||
if n > 0 {
|
||||
buf = buf[:n]
|
||||
@ -162,6 +162,61 @@ func (h *SIPHandler) Handle(cx *layer4.Connection, next layer4.Handler) error {
|
||||
zap.Int("bytes", n),
|
||||
)
|
||||
|
||||
// Record message size metric
|
||||
if enableMetrics {
|
||||
RecordMessageSize(n)
|
||||
}
|
||||
|
||||
// Validate SIP message structure and content
|
||||
validator := GetValidator(h.logger)
|
||||
if validator.IsEnabled() {
|
||||
validationResult := validator.Validate(buf)
|
||||
|
||||
// Record metrics for violations
|
||||
if enableMetrics {
|
||||
for _, v := range validationResult.Violations {
|
||||
RecordValidationViolation(v.Rule)
|
||||
}
|
||||
if validationResult.Valid {
|
||||
RecordValidationResult("valid")
|
||||
} else if validationResult.ShouldBan {
|
||||
RecordValidationResult("ban")
|
||||
} else {
|
||||
RecordValidationResult("invalid")
|
||||
}
|
||||
}
|
||||
|
||||
if !validationResult.Valid {
|
||||
h.logger.Warn("SIP validation failed",
|
||||
zap.String("ip", host),
|
||||
zap.Int("violation_count", len(validationResult.Violations)),
|
||||
zap.Bool("should_ban", validationResult.ShouldBan),
|
||||
)
|
||||
|
||||
// Log individual violations at debug level
|
||||
for _, v := range validationResult.Violations {
|
||||
h.logger.Debug("Validation violation",
|
||||
zap.String("ip", host),
|
||||
zap.String("rule", v.Rule),
|
||||
zap.String("severity", string(v.Severity)),
|
||||
zap.String("message", v.Message),
|
||||
)
|
||||
}
|
||||
|
||||
if validationResult.ShouldBan {
|
||||
if enableMetrics {
|
||||
RecordConnection("validation_blocked")
|
||||
}
|
||||
h.guardian.RecordFailure(host, validationResult.BanReason)
|
||||
return cx.Close()
|
||||
}
|
||||
|
||||
// In strict/paranoid mode, any violation rejects the message
|
||||
// but we already counted it toward ban threshold above via RecordFailure
|
||||
// For now in permissive mode, log and continue
|
||||
}
|
||||
}
|
||||
|
||||
// Extract SIP method for rate limiting
|
||||
method := ExtractSIPMethod(buf)
|
||||
if method != "" {
|
||||
@ -362,6 +417,14 @@ func (m *SIPMatcher) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
||||
// ban_time 2h
|
||||
// exempt_extensions 100 200 9999
|
||||
// }
|
||||
// validation {
|
||||
// enabled true
|
||||
// mode permissive # permissive, strict, paranoid
|
||||
// max_message_size 65535
|
||||
// ban_on_null_bytes true
|
||||
// ban_on_binary_injection true
|
||||
// disabled_rules via_invalid_branch cseq_out_of_range
|
||||
// }
|
||||
// webhook http://example.com/hook { ... }
|
||||
// }
|
||||
//
|
||||
@ -507,6 +570,50 @@ func (h *SIPHandler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
||||
}
|
||||
}
|
||||
|
||||
case "validation":
|
||||
h.Validation = &ValidationConfig{}
|
||||
for innerNesting := d.Nesting(); d.NextBlock(innerNesting); {
|
||||
switch d.Val() {
|
||||
case "enabled":
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
h.Validation.Enabled = d.Val() == "true" || d.Val() == "yes" || d.Val() == "on"
|
||||
case "mode":
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
mode := ValidationMode(d.Val())
|
||||
if mode != ValidationModePermissive && mode != ValidationModeStrict && mode != ValidationModeParanoid {
|
||||
return d.Errf("invalid validation mode: %s (must be permissive, strict, or paranoid)", d.Val())
|
||||
}
|
||||
h.Validation.Mode = mode
|
||||
case "max_message_size":
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
var val int
|
||||
if _, err := fmt.Sscanf(d.Val(), "%d", &val); err != nil {
|
||||
return d.Errf("invalid max_message_size: %v", err)
|
||||
}
|
||||
h.Validation.MaxMessageSize = val
|
||||
case "ban_on_null_bytes":
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
h.Validation.BanOnNullBytes = d.Val() == "true" || d.Val() == "yes" || d.Val() == "on"
|
||||
case "ban_on_binary_injection":
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
}
|
||||
h.Validation.BanOnBinaryInjection = d.Val() == "true" || d.Val() == "yes" || d.Val() == "on"
|
||||
case "disabled_rules":
|
||||
h.Validation.DisabledRules = d.RemainingArgs()
|
||||
default:
|
||||
return d.Errf("unknown validation directive: %s", d.Val())
|
||||
}
|
||||
}
|
||||
|
||||
case "webhook":
|
||||
if !d.NextArg() {
|
||||
return d.ArgErr()
|
||||
|
||||
46
metrics.go
46
metrics.go
@ -121,6 +121,34 @@ var (
|
||||
Buckets: []float64{5, 10, 15, 20, 30, 50, 100},
|
||||
},
|
||||
)
|
||||
|
||||
// Validation metrics
|
||||
sipValidationViolations = prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Namespace: "sip_guardian",
|
||||
Name: "validation_violations_total",
|
||||
Help: "Total validation violations detected by rule",
|
||||
},
|
||||
[]string{"rule"}, // "null_bytes", "binary_injection", "missing_via", etc.
|
||||
)
|
||||
|
||||
sipValidationResults = prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Namespace: "sip_guardian",
|
||||
Name: "validation_results_total",
|
||||
Help: "Total validation results by outcome",
|
||||
},
|
||||
[]string{"result"}, // "valid", "invalid", "ban"
|
||||
)
|
||||
|
||||
sipMessageSizeBytes = prometheus.NewHistogram(
|
||||
prometheus.HistogramOpts{
|
||||
Namespace: "sip_guardian",
|
||||
Name: "message_size_bytes",
|
||||
Help: "Distribution of SIP message sizes in bytes",
|
||||
Buckets: []float64{100, 500, 1000, 2000, 5000, 10000, 20000, 50000, 65535},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
// metricsRegistered tracks if we've registered with Prometheus
|
||||
@ -146,6 +174,9 @@ func RegisterMetrics() {
|
||||
sipEnumerationDetections,
|
||||
sipEnumerationTrackedIPs,
|
||||
sipEnumerationUniqueExtensions,
|
||||
sipValidationViolations,
|
||||
sipValidationResults,
|
||||
sipMessageSizeBytes,
|
||||
)
|
||||
}
|
||||
|
||||
@ -213,6 +244,21 @@ func RecordEnumerationExtensions(count int) {
|
||||
sipEnumerationUniqueExtensions.Observe(float64(count))
|
||||
}
|
||||
|
||||
// RecordValidationViolation records a validation violation by rule name
|
||||
func RecordValidationViolation(rule string) {
|
||||
sipValidationViolations.WithLabelValues(rule).Inc()
|
||||
}
|
||||
|
||||
// RecordValidationResult records a validation result (valid, invalid, ban)
|
||||
func RecordValidationResult(result string) {
|
||||
sipValidationResults.WithLabelValues(result).Inc()
|
||||
}
|
||||
|
||||
// RecordMessageSize records the size of a SIP message in bytes
|
||||
func RecordMessageSize(bytes int) {
|
||||
sipMessageSizeBytes.Observe(float64(bytes))
|
||||
}
|
||||
|
||||
// MetricsHandler provides a Prometheus metrics endpoint for SIP Guardian
|
||||
type MetricsHandler struct {
|
||||
// Path prefix for metrics (default: /metrics)
|
||||
|
||||
@ -55,6 +55,9 @@ type SIPGuardian struct {
|
||||
// Enumeration detection configuration
|
||||
Enumeration *EnumerationConfig `json:"enumeration,omitempty"`
|
||||
|
||||
// Validation configuration
|
||||
Validation *ValidationConfig `json:"validation,omitempty"`
|
||||
|
||||
// Runtime state
|
||||
logger *zap.Logger
|
||||
bannedIPs map[string]*BanEntry
|
||||
@ -162,6 +165,16 @@ func (g *SIPGuardian) Provision(ctx caddy.Context) error {
|
||||
)
|
||||
}
|
||||
|
||||
// Initialize validation with config if specified
|
||||
if g.Validation != nil {
|
||||
SetValidationConfig(*g.Validation)
|
||||
g.logger.Info("SIP validation configured",
|
||||
zap.String("mode", string(g.Validation.Mode)),
|
||||
zap.Bool("enabled", g.Validation.Enabled),
|
||||
zap.Int("max_message_size", g.Validation.MaxMessageSize),
|
||||
)
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
go g.cleanupLoop(ctx)
|
||||
|
||||
@ -174,6 +187,7 @@ func (g *SIPGuardian) Provision(ctx caddy.Context) error {
|
||||
zap.Bool("geoip_enabled", g.geoIP != nil),
|
||||
zap.Int("webhook_count", len(g.Webhooks)),
|
||||
zap.Bool("enumeration_enabled", g.Enumeration != nil),
|
||||
zap.Bool("validation_enabled", g.Validation != nil && g.Validation.Enabled),
|
||||
)
|
||||
|
||||
return nil
|
||||
|
||||
709
validation.go
Normal file
709
validation.go
Normal file
@ -0,0 +1,709 @@
|
||||
package sipguardian
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ValidationConfig holds configuration for SIP message validation
|
||||
type ValidationConfig struct {
|
||||
// Enabled determines if validation is active
|
||||
Enabled bool `json:"enabled,omitempty"`
|
||||
|
||||
// Mode determines validation strictness: permissive, strict, paranoid
|
||||
Mode ValidationMode `json:"mode,omitempty"`
|
||||
|
||||
// MaxMessageSize limits SIP message size (default: 65535)
|
||||
MaxMessageSize int `json:"max_message_size,omitempty"`
|
||||
|
||||
// BanOnNullBytes causes immediate ban for null byte injection
|
||||
BanOnNullBytes bool `json:"ban_on_null_bytes,omitempty"`
|
||||
|
||||
// BanOnBinaryInjection causes immediate ban for binary data injection
|
||||
BanOnBinaryInjection bool `json:"ban_on_binary_injection,omitempty"`
|
||||
|
||||
// DisabledRules lists rules to skip
|
||||
DisabledRules []string `json:"disabled_rules,omitempty"`
|
||||
}
|
||||
|
||||
// ValidationMode determines validation strictness
|
||||
type ValidationMode string
|
||||
|
||||
const (
|
||||
// ValidationModePermissive logs violations but only blocks critical attacks
|
||||
ValidationModePermissive ValidationMode = "permissive"
|
||||
|
||||
// ValidationModeStrict enforces RFC 3261 compliance
|
||||
ValidationModeStrict ValidationMode = "strict"
|
||||
|
||||
// ValidationModeParanoid applies extra security heuristics
|
||||
ValidationModeParanoid ValidationMode = "paranoid"
|
||||
)
|
||||
|
||||
// ViolationSeverity indicates how serious a violation is
|
||||
type ViolationSeverity string
|
||||
|
||||
const (
|
||||
SeverityCritical ViolationSeverity = "critical" // Immediate ban
|
||||
SeverityHigh ViolationSeverity = "high" // Count toward ban
|
||||
SeverityMedium ViolationSeverity = "medium" // Reject only
|
||||
SeverityLow ViolationSeverity = "low" // Log only
|
||||
)
|
||||
|
||||
// Violation represents a single validation rule violation
|
||||
type Violation struct {
|
||||
Rule string `json:"rule"`
|
||||
Severity ViolationSeverity `json:"severity"`
|
||||
Message string `json:"message"`
|
||||
Details string `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
// ValidationResult contains the result of validating a SIP message
|
||||
type ValidationResult struct {
|
||||
Valid bool `json:"valid"`
|
||||
Violations []Violation `json:"violations,omitempty"`
|
||||
ShouldBan bool `json:"should_ban"`
|
||||
BanReason string `json:"ban_reason,omitempty"`
|
||||
}
|
||||
|
||||
// SIPValidator validates SIP messages against RFC 3261 and security rules
|
||||
type SIPValidator struct {
|
||||
config ValidationConfig
|
||||
disabledSet map[string]bool
|
||||
logger *zap.Logger
|
||||
mu sync.RWMutex
|
||||
stats validatorStats
|
||||
}
|
||||
|
||||
type validatorStats struct {
|
||||
totalValidated int64
|
||||
totalValid int64
|
||||
totalInvalid int64
|
||||
violationsByRule map[string]int64
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// DefaultValidationConfig returns sensible defaults
|
||||
func DefaultValidationConfig() ValidationConfig {
|
||||
return ValidationConfig{
|
||||
Enabled: true,
|
||||
Mode: ValidationModePermissive,
|
||||
MaxMessageSize: 65535,
|
||||
BanOnNullBytes: true,
|
||||
BanOnBinaryInjection: true,
|
||||
DisabledRules: []string{},
|
||||
}
|
||||
}
|
||||
|
||||
// Global validator instance
|
||||
var (
|
||||
globalValidator *SIPValidator
|
||||
validatorMu sync.RWMutex
|
||||
)
|
||||
|
||||
// GetValidator returns the global validator instance
|
||||
func GetValidator(logger *zap.Logger) *SIPValidator {
|
||||
validatorMu.Lock()
|
||||
defer validatorMu.Unlock()
|
||||
|
||||
if globalValidator == nil {
|
||||
globalValidator = NewSIPValidator(logger, DefaultValidationConfig())
|
||||
}
|
||||
return globalValidator
|
||||
}
|
||||
|
||||
// SetValidationConfig updates the global validator configuration
|
||||
func SetValidationConfig(config ValidationConfig) {
|
||||
validatorMu.Lock()
|
||||
defer validatorMu.Unlock()
|
||||
|
||||
if globalValidator == nil {
|
||||
globalValidator = &SIPValidator{
|
||||
config: config,
|
||||
disabledSet: make(map[string]bool),
|
||||
logger: zap.NewNop(),
|
||||
stats: validatorStats{
|
||||
violationsByRule: make(map[string]int64),
|
||||
},
|
||||
}
|
||||
} else {
|
||||
globalValidator.config = config
|
||||
}
|
||||
|
||||
// Rebuild disabled rules set
|
||||
globalValidator.disabledSet = make(map[string]bool)
|
||||
for _, rule := range config.DisabledRules {
|
||||
globalValidator.disabledSet[rule] = true
|
||||
}
|
||||
}
|
||||
|
||||
// NewSIPValidator creates a new validator
|
||||
func NewSIPValidator(logger *zap.Logger, config ValidationConfig) *SIPValidator {
|
||||
disabledSet := make(map[string]bool)
|
||||
for _, rule := range config.DisabledRules {
|
||||
disabledSet[rule] = true
|
||||
}
|
||||
|
||||
return &SIPValidator{
|
||||
config: config,
|
||||
disabledSet: disabledSet,
|
||||
logger: logger,
|
||||
stats: validatorStats{
|
||||
violationsByRule: make(map[string]int64),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Validate checks a SIP message for RFC compliance and security issues
|
||||
func (v *SIPValidator) Validate(data []byte) ValidationResult {
|
||||
v.mu.RLock()
|
||||
config := v.config
|
||||
disabled := v.disabledSet
|
||||
v.mu.RUnlock()
|
||||
|
||||
if !config.Enabled {
|
||||
return ValidationResult{Valid: true}
|
||||
}
|
||||
|
||||
result := ValidationResult{Valid: true}
|
||||
var violations []Violation
|
||||
|
||||
// Critical checks first (can cause immediate ban)
|
||||
if !disabled["null_bytes"] {
|
||||
if v := v.checkNullBytes(data); v != nil {
|
||||
violations = append(violations, *v)
|
||||
if config.BanOnNullBytes {
|
||||
result.ShouldBan = true
|
||||
result.BanReason = "validation_null_bytes"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !disabled["binary_injection"] {
|
||||
if v := v.checkBinaryInjection(data); v != nil {
|
||||
violations = append(violations, *v)
|
||||
if config.BanOnBinaryInjection {
|
||||
result.ShouldBan = true
|
||||
result.BanReason = "validation_binary_injection"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Size check
|
||||
if !disabled["oversized_message"] {
|
||||
if v := v.checkMessageSize(data, config.MaxMessageSize); v != nil {
|
||||
violations = append(violations, *v)
|
||||
}
|
||||
}
|
||||
|
||||
// Parse headers for further validation
|
||||
headers := v.parseHeaders(data)
|
||||
|
||||
// Required header checks
|
||||
if !disabled["missing_via"] {
|
||||
if v := v.checkRequiredHeader(headers, "Via", SeverityHigh); v != nil {
|
||||
violations = append(violations, *v)
|
||||
}
|
||||
}
|
||||
|
||||
if !disabled["missing_from"] {
|
||||
if v := v.checkRequiredHeader(headers, "From", SeverityHigh); v != nil {
|
||||
violations = append(violations, *v)
|
||||
}
|
||||
}
|
||||
|
||||
if !disabled["missing_to"] {
|
||||
if v := v.checkRequiredHeader(headers, "To", SeverityHigh); v != nil {
|
||||
violations = append(violations, *v)
|
||||
}
|
||||
}
|
||||
|
||||
if !disabled["missing_call_id"] {
|
||||
if v := v.checkRequiredHeader(headers, "Call-ID", SeverityHigh); v != nil {
|
||||
violations = append(violations, *v)
|
||||
}
|
||||
}
|
||||
|
||||
if !disabled["missing_cseq"] {
|
||||
if v := v.checkRequiredHeader(headers, "CSeq", SeverityHigh); v != nil {
|
||||
violations = append(violations, *v)
|
||||
}
|
||||
}
|
||||
|
||||
if !disabled["missing_max_forwards"] {
|
||||
if v := v.checkRequiredHeader(headers, "Max-Forwards", SeverityMedium); v != nil {
|
||||
violations = append(violations, *v)
|
||||
}
|
||||
}
|
||||
|
||||
// CSeq validation
|
||||
if !disabled["cseq_out_of_range"] {
|
||||
if v := v.checkCSeq(headers); v != nil {
|
||||
violations = append(violations, *v)
|
||||
}
|
||||
}
|
||||
|
||||
// Content-Length validation
|
||||
if !disabled["content_length_mismatch"] {
|
||||
if v := v.checkContentLength(headers, data); v != nil {
|
||||
violations = append(violations, *v)
|
||||
}
|
||||
}
|
||||
|
||||
// Via branch validation (strict/paranoid mode)
|
||||
if config.Mode == ValidationModeStrict || config.Mode == ValidationModeParanoid {
|
||||
if !disabled["via_invalid_branch"] {
|
||||
if v := v.checkViaBranch(headers); v != nil {
|
||||
violations = append(violations, *v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Request-URI validation (strict/paranoid mode)
|
||||
if config.Mode == ValidationModeStrict || config.Mode == ValidationModeParanoid {
|
||||
if !disabled["invalid_request_uri"] {
|
||||
if v := v.checkRequestURI(data); v != nil {
|
||||
violations = append(violations, *v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Paranoid mode extra checks
|
||||
if config.Mode == ValidationModeParanoid {
|
||||
if !disabled["suspicious_headers"] {
|
||||
violations = append(violations, v.checkSuspiciousHeaders(headers)...)
|
||||
}
|
||||
}
|
||||
|
||||
// Process violations
|
||||
if len(violations) > 0 {
|
||||
result.Valid = false
|
||||
result.Violations = violations
|
||||
|
||||
// Record stats
|
||||
v.stats.mu.Lock()
|
||||
v.stats.totalValidated++
|
||||
v.stats.totalInvalid++
|
||||
for _, viol := range violations {
|
||||
v.stats.violationsByRule[viol.Rule]++
|
||||
}
|
||||
v.stats.mu.Unlock()
|
||||
|
||||
// Determine if violations warrant a ban (in strict/paranoid mode)
|
||||
if !result.ShouldBan && (config.Mode == ValidationModeStrict || config.Mode == ValidationModeParanoid) {
|
||||
for _, viol := range violations {
|
||||
if viol.Severity == SeverityCritical || viol.Severity == SeverityHigh {
|
||||
result.ShouldBan = true
|
||||
result.BanReason = "validation_" + viol.Rule
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
v.stats.mu.Lock()
|
||||
v.stats.totalValidated++
|
||||
v.stats.totalValid++
|
||||
v.stats.mu.Unlock()
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// checkNullBytes detects null byte injection attacks
|
||||
func (v *SIPValidator) checkNullBytes(data []byte) *Violation {
|
||||
if bytes.Contains(data, []byte{0x00}) {
|
||||
return &Violation{
|
||||
Rule: "null_bytes",
|
||||
Severity: SeverityCritical,
|
||||
Message: "Null byte injection detected",
|
||||
Details: "SIP message contains null bytes which may indicate an injection attack",
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkBinaryInjection detects non-ASCII binary data in headers
|
||||
func (v *SIPValidator) checkBinaryInjection(data []byte) *Violation {
|
||||
// Split headers from body
|
||||
parts := bytes.SplitN(data, []byte("\r\n\r\n"), 2)
|
||||
headerSection := parts[0]
|
||||
|
||||
// Check for non-printable ASCII in headers (except CR, LF, TAB)
|
||||
for i, b := range headerSection {
|
||||
if b < 0x09 || (b > 0x0D && b < 0x20) || b > 0x7E {
|
||||
// Allow UTF-8 in display names, but flag truly binary data
|
||||
if b < 0x80 {
|
||||
return &Violation{
|
||||
Rule: "binary_injection",
|
||||
Severity: SeverityCritical,
|
||||
Message: "Binary data injection detected",
|
||||
Details: "Non-printable character 0x" + strconv.FormatInt(int64(b), 16) + " at position " + strconv.Itoa(i),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkMessageSize validates message doesn't exceed limits
|
||||
func (v *SIPValidator) checkMessageSize(data []byte, maxSize int) *Violation {
|
||||
if len(data) > maxSize {
|
||||
return &Violation{
|
||||
Rule: "oversized_message",
|
||||
Severity: SeverityHigh,
|
||||
Message: "Message exceeds maximum size",
|
||||
Details: "Size " + strconv.Itoa(len(data)) + " exceeds limit " + strconv.Itoa(maxSize),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseHeaders extracts headers from SIP message
|
||||
func (v *SIPValidator) parseHeaders(data []byte) map[string][]string {
|
||||
headers := make(map[string][]string)
|
||||
|
||||
// Split headers from body
|
||||
parts := bytes.SplitN(data, []byte("\r\n\r\n"), 2)
|
||||
headerSection := string(parts[0])
|
||||
|
||||
lines := strings.Split(headerSection, "\r\n")
|
||||
// Skip request/status line
|
||||
for i := 1; i < len(lines); i++ {
|
||||
line := lines[i]
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle header continuation (line starting with whitespace)
|
||||
if len(line) > 0 && (line[0] == ' ' || line[0] == '\t') {
|
||||
// Append to previous header
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse header name: value
|
||||
colonIdx := strings.Index(line, ":")
|
||||
if colonIdx > 0 {
|
||||
name := strings.TrimSpace(line[:colonIdx])
|
||||
value := strings.TrimSpace(line[colonIdx+1:])
|
||||
|
||||
// Normalize header name (compact forms)
|
||||
name = normalizeHeaderName(name)
|
||||
headers[name] = append(headers[name], value)
|
||||
}
|
||||
}
|
||||
|
||||
return headers
|
||||
}
|
||||
|
||||
// normalizeHeaderName handles SIP compact header forms
|
||||
func normalizeHeaderName(name string) string {
|
||||
// Map compact forms to full names
|
||||
compactForms := map[string]string{
|
||||
"i": "Call-ID",
|
||||
"m": "Contact",
|
||||
"e": "Content-Encoding",
|
||||
"l": "Content-Length",
|
||||
"c": "Content-Type",
|
||||
"f": "From",
|
||||
"s": "Subject",
|
||||
"k": "Supported",
|
||||
"t": "To",
|
||||
"v": "Via",
|
||||
}
|
||||
|
||||
lower := strings.ToLower(name)
|
||||
if full, ok := compactForms[lower]; ok {
|
||||
return full
|
||||
}
|
||||
|
||||
// Normalize case for common headers
|
||||
switch lower {
|
||||
case "via":
|
||||
return "Via"
|
||||
case "from":
|
||||
return "From"
|
||||
case "to":
|
||||
return "To"
|
||||
case "call-id":
|
||||
return "Call-ID"
|
||||
case "cseq":
|
||||
return "CSeq"
|
||||
case "max-forwards":
|
||||
return "Max-Forwards"
|
||||
case "content-length":
|
||||
return "Content-Length"
|
||||
case "contact":
|
||||
return "Contact"
|
||||
default:
|
||||
return name
|
||||
}
|
||||
}
|
||||
|
||||
// checkRequiredHeader validates a required header is present
|
||||
func (v *SIPValidator) checkRequiredHeader(headers map[string][]string, name string, severity ViolationSeverity) *Violation {
|
||||
if values, ok := headers[name]; !ok || len(values) == 0 || values[0] == "" {
|
||||
return &Violation{
|
||||
Rule: "missing_" + strings.ToLower(strings.ReplaceAll(name, "-", "_")),
|
||||
Severity: severity,
|
||||
Message: "Required header missing: " + name,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkCSeq validates CSeq header format and range
|
||||
func (v *SIPValidator) checkCSeq(headers map[string][]string) *Violation {
|
||||
values, ok := headers["CSeq"]
|
||||
if !ok || len(values) == 0 {
|
||||
return nil // Missing CSeq handled by required header check
|
||||
}
|
||||
|
||||
cseq := values[0]
|
||||
parts := strings.Fields(cseq)
|
||||
if len(parts) < 2 {
|
||||
return &Violation{
|
||||
Rule: "cseq_invalid_format",
|
||||
Severity: SeverityHigh,
|
||||
Message: "Invalid CSeq format",
|
||||
Details: "Expected 'sequence method', got: " + cseq,
|
||||
}
|
||||
}
|
||||
|
||||
seq, err := strconv.ParseInt(parts[0], 10, 64)
|
||||
if err != nil {
|
||||
return &Violation{
|
||||
Rule: "cseq_invalid_sequence",
|
||||
Severity: SeverityHigh,
|
||||
Message: "CSeq sequence number invalid",
|
||||
Details: "Could not parse: " + parts[0],
|
||||
}
|
||||
}
|
||||
|
||||
// RFC 3261: CSeq must be 0 to 2^31-1
|
||||
if seq < 0 || seq > 2147483647 {
|
||||
return &Violation{
|
||||
Rule: "cseq_out_of_range",
|
||||
Severity: SeverityMedium,
|
||||
Message: "CSeq sequence number out of range",
|
||||
Details: "Value " + parts[0] + " outside valid range 0-2147483647",
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkContentLength validates Content-Length matches actual body
|
||||
func (v *SIPValidator) checkContentLength(headers map[string][]string, data []byte) *Violation {
|
||||
values, ok := headers["Content-Length"]
|
||||
if !ok || len(values) == 0 {
|
||||
return nil // Optional header
|
||||
}
|
||||
|
||||
claimLength, err := strconv.Atoi(values[0])
|
||||
if err != nil {
|
||||
return &Violation{
|
||||
Rule: "content_length_invalid",
|
||||
Severity: SeverityHigh,
|
||||
Message: "Invalid Content-Length value",
|
||||
Details: "Could not parse: " + values[0],
|
||||
}
|
||||
}
|
||||
|
||||
// Find actual body
|
||||
parts := bytes.SplitN(data, []byte("\r\n\r\n"), 2)
|
||||
actualLength := 0
|
||||
if len(parts) > 1 {
|
||||
actualLength = len(parts[1])
|
||||
}
|
||||
|
||||
// Allow some tolerance for line ending differences
|
||||
if claimLength != actualLength {
|
||||
// Check if difference is just CRLF vs LF
|
||||
if abs(claimLength-actualLength) > 2 {
|
||||
return &Violation{
|
||||
Rule: "content_length_mismatch",
|
||||
Severity: SeverityHigh,
|
||||
Message: "Content-Length doesn't match body size",
|
||||
Details: "Claimed " + values[0] + ", actual " + strconv.Itoa(actualLength),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkViaBranch validates Via branch parameter format
|
||||
func (v *SIPValidator) checkViaBranch(headers map[string][]string) *Violation {
|
||||
values, ok := headers["Via"]
|
||||
if !ok || len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
via := values[0]
|
||||
// RFC 3261: branch must start with "z9hG4bK" for 3261-compliant UAs
|
||||
branchPattern := regexp.MustCompile(`branch=([^;,\s]+)`)
|
||||
matches := branchPattern.FindStringSubmatch(via)
|
||||
|
||||
if len(matches) > 1 {
|
||||
branch := matches[1]
|
||||
if !strings.HasPrefix(branch, "z9hG4bK") {
|
||||
return &Violation{
|
||||
Rule: "via_invalid_branch",
|
||||
Severity: SeverityLow,
|
||||
Message: "Via branch doesn't follow RFC 3261 format",
|
||||
Details: "Branch should start with 'z9hG4bK', got: " + branch,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkRequestURI validates Request-URI format
|
||||
func (v *SIPValidator) checkRequestURI(data []byte) *Violation {
|
||||
// Get first line
|
||||
idx := bytes.Index(data, []byte("\r\n"))
|
||||
if idx < 0 {
|
||||
return &Violation{
|
||||
Rule: "invalid_request_line",
|
||||
Severity: SeverityHigh,
|
||||
Message: "Could not parse request line",
|
||||
}
|
||||
}
|
||||
|
||||
requestLine := string(data[:idx])
|
||||
parts := strings.Fields(requestLine)
|
||||
|
||||
if len(parts) < 3 {
|
||||
return &Violation{
|
||||
Rule: "invalid_request_line",
|
||||
Severity: SeverityHigh,
|
||||
Message: "Malformed request line",
|
||||
Details: requestLine,
|
||||
}
|
||||
}
|
||||
|
||||
// Check if it's a response (starts with SIP/2.0)
|
||||
if strings.HasPrefix(parts[0], "SIP/") {
|
||||
return nil // Skip URI validation for responses
|
||||
}
|
||||
|
||||
uri := parts[1]
|
||||
|
||||
// Basic SIP URI validation
|
||||
if !strings.HasPrefix(strings.ToLower(uri), "sip:") &&
|
||||
!strings.HasPrefix(strings.ToLower(uri), "sips:") &&
|
||||
!strings.HasPrefix(strings.ToLower(uri), "tel:") {
|
||||
return &Violation{
|
||||
Rule: "invalid_request_uri",
|
||||
Severity: SeverityHigh,
|
||||
Message: "Invalid Request-URI scheme",
|
||||
Details: "Expected sip:, sips:, or tel: URI, got: " + uri,
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkSuspiciousHeaders detects potentially malicious header patterns
|
||||
func (v *SIPValidator) checkSuspiciousHeaders(headers map[string][]string) []Violation {
|
||||
var violations []Violation
|
||||
|
||||
// Check for extremely long header values
|
||||
for name, values := range headers {
|
||||
for _, val := range values {
|
||||
if len(val) > 1024 {
|
||||
violations = append(violations, Violation{
|
||||
Rule: "suspicious_long_header",
|
||||
Severity: SeverityMedium,
|
||||
Message: "Unusually long header value",
|
||||
Details: name + " has " + strconv.Itoa(len(val)) + " bytes",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for excessive number of Via headers (potential loop)
|
||||
if vias, ok := headers["Via"]; ok && len(vias) > 70 {
|
||||
violations = append(violations, Violation{
|
||||
Rule: "suspicious_via_count",
|
||||
Severity: SeverityMedium,
|
||||
Message: "Excessive Via headers",
|
||||
Details: strconv.Itoa(len(vias)) + " Via headers (potential loop)",
|
||||
})
|
||||
}
|
||||
|
||||
// Check for SQL injection patterns in headers
|
||||
sqlPatterns := []string{
|
||||
"' OR '",
|
||||
"' OR 1=1",
|
||||
"'; DROP",
|
||||
"UNION SELECT",
|
||||
"--",
|
||||
}
|
||||
for name, values := range headers {
|
||||
for _, val := range values {
|
||||
upperVal := strings.ToUpper(val)
|
||||
for _, pattern := range sqlPatterns {
|
||||
if strings.Contains(upperVal, pattern) {
|
||||
violations = append(violations, Violation{
|
||||
Rule: "suspicious_sql_injection",
|
||||
Severity: SeverityHigh,
|
||||
Message: "Potential SQL injection in header",
|
||||
Details: name + " contains suspicious pattern",
|
||||
})
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return violations
|
||||
}
|
||||
|
||||
// GetStats returns validator statistics
|
||||
func (v *SIPValidator) GetStats() map[string]interface{} {
|
||||
v.stats.mu.Lock()
|
||||
defer v.stats.mu.Unlock()
|
||||
|
||||
violations := make(map[string]int64)
|
||||
for rule, count := range v.stats.violationsByRule {
|
||||
violations[rule] = count
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_validated": v.stats.totalValidated,
|
||||
"total_valid": v.stats.totalValid,
|
||||
"total_invalid": v.stats.totalInvalid,
|
||||
"violations_by_rule": violations,
|
||||
"config_mode": v.config.Mode,
|
||||
"config_enabled": v.config.Enabled,
|
||||
}
|
||||
}
|
||||
|
||||
// IsEnabled returns whether validation is enabled
|
||||
func (v *SIPValidator) IsEnabled() bool {
|
||||
v.mu.RLock()
|
||||
defer v.mu.RUnlock()
|
||||
return v.config.Enabled
|
||||
}
|
||||
|
||||
// abs returns absolute value of an int
|
||||
func abs(n int) int {
|
||||
if n < 0 {
|
||||
return -n
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// Cleanup resets old statistics (called periodically)
|
||||
func (v *SIPValidator) Cleanup(maxAge time.Duration) {
|
||||
// For now, validation stats don't need time-based cleanup
|
||||
// Could add per-IP validation tracking in the future
|
||||
}
|
||||
608
validation_test.go
Normal file
608
validation_test.go
Normal file
@ -0,0 +1,608 @@
|
||||
package sipguardian
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestDefaultValidationConfig(t *testing.T) {
|
||||
config := DefaultValidationConfig()
|
||||
|
||||
if !config.Enabled {
|
||||
t.Error("Expected validation to be enabled by default")
|
||||
}
|
||||
if config.Mode != ValidationModePermissive {
|
||||
t.Errorf("Expected permissive mode, got %s", config.Mode)
|
||||
}
|
||||
if config.MaxMessageSize != 65535 {
|
||||
t.Errorf("Expected max message size 65535, got %d", config.MaxMessageSize)
|
||||
}
|
||||
if !config.BanOnNullBytes {
|
||||
t.Error("Expected ban on null bytes by default")
|
||||
}
|
||||
if !config.BanOnBinaryInjection {
|
||||
t.Error("Expected ban on binary injection by default")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidationNullBytes(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
validator := NewSIPValidator(logger, DefaultValidationConfig())
|
||||
|
||||
// Valid SIP message
|
||||
validMsg := []byte("REGISTER sip:example.com SIP/2.0\r\n" +
|
||||
"Via: SIP/2.0/UDP 192.168.1.1:5060;branch=z9hG4bK123\r\n" +
|
||||
"From: <sip:user@example.com>;tag=abc\r\n" +
|
||||
"To: <sip:user@example.com>\r\n" +
|
||||
"Call-ID: 123456@192.168.1.1\r\n" +
|
||||
"CSeq: 1 REGISTER\r\n" +
|
||||
"Max-Forwards: 70\r\n" +
|
||||
"Content-Length: 0\r\n\r\n")
|
||||
|
||||
result := validator.Validate(validMsg)
|
||||
if result.ShouldBan {
|
||||
t.Error("Valid message should not trigger ban")
|
||||
}
|
||||
|
||||
// Message with null byte at the end (won't trigger binary injection first)
|
||||
nullMsg := []byte("REGISTER sip:example.com SIP/2.0\r\n" +
|
||||
"Via: SIP/2.0/UDP 192.168.1.1:5060;branch=z9hG4bK123\r\n" +
|
||||
"From: <sip:user@example.com>;tag=abc\r\n" +
|
||||
"To: <sip:user@example.com>\r\n" +
|
||||
"Call-ID: 123456@192.168.1.1\r\n" +
|
||||
"CSeq: 1 REGISTER\r\n" +
|
||||
"Max-Forwards: 70\r\n" +
|
||||
"Content-Length: 5\r\n\r\ntest\x00")
|
||||
|
||||
result = validator.Validate(nullMsg)
|
||||
if !result.ShouldBan {
|
||||
t.Error("Null byte injection should trigger ban")
|
||||
}
|
||||
// Note: null_bytes is checked first, so ban reason should be null_bytes
|
||||
if result.BanReason != "validation_null_bytes" {
|
||||
t.Errorf("Expected ban reason 'validation_null_bytes', got '%s'", result.BanReason)
|
||||
}
|
||||
|
||||
// Check violation was recorded
|
||||
hasNullViolation := false
|
||||
for _, v := range result.Violations {
|
||||
if v.Rule == "null_bytes" && v.Severity == SeverityCritical {
|
||||
hasNullViolation = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasNullViolation {
|
||||
t.Error("Expected null_bytes violation with critical severity")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidationBinaryInjection(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
validator := NewSIPValidator(logger, DefaultValidationConfig())
|
||||
|
||||
// Message with binary control character (bell)
|
||||
binaryMsg := []byte("REGISTER sip:example.com SIP/2.0\r\n" +
|
||||
"Via: SIP/2.0/UDP 192.168.1.1:5060;branch=z9hG4bK123\r\n" +
|
||||
"From: <sip:\x07user@example.com>;tag=abc\r\n" +
|
||||
"To: <sip:user@example.com>\r\n" +
|
||||
"Call-ID: 123456@192.168.1.1\r\n" +
|
||||
"CSeq: 1 REGISTER\r\n\r\n")
|
||||
|
||||
result := validator.Validate(binaryMsg)
|
||||
if !result.ShouldBan {
|
||||
t.Error("Binary injection should trigger ban")
|
||||
}
|
||||
if result.BanReason != "validation_binary_injection" {
|
||||
t.Errorf("Expected ban reason 'validation_binary_injection', got '%s'", result.BanReason)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidationMissingHeaders(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
config := DefaultValidationConfig()
|
||||
config.Mode = ValidationModeStrict // Missing headers should ban in strict mode
|
||||
validator := NewSIPValidator(logger, config)
|
||||
|
||||
// Message missing Via header
|
||||
noViaMsg := []byte("REGISTER sip:example.com SIP/2.0\r\n" +
|
||||
"From: <sip:user@example.com>;tag=abc\r\n" +
|
||||
"To: <sip:user@example.com>\r\n" +
|
||||
"Call-ID: 123456@192.168.1.1\r\n" +
|
||||
"CSeq: 1 REGISTER\r\n" +
|
||||
"Max-Forwards: 70\r\n\r\n")
|
||||
|
||||
result := validator.Validate(noViaMsg)
|
||||
if result.Valid {
|
||||
t.Error("Message missing Via should be invalid")
|
||||
}
|
||||
|
||||
hasViaViolation := false
|
||||
for _, v := range result.Violations {
|
||||
if v.Rule == "missing_via" {
|
||||
hasViaViolation = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasViaViolation {
|
||||
t.Error("Expected missing_via violation")
|
||||
}
|
||||
|
||||
// Message missing Call-ID
|
||||
noCallIDMsg := []byte("REGISTER sip:example.com SIP/2.0\r\n" +
|
||||
"Via: SIP/2.0/UDP 192.168.1.1:5060;branch=z9hG4bK123\r\n" +
|
||||
"From: <sip:user@example.com>;tag=abc\r\n" +
|
||||
"To: <sip:user@example.com>\r\n" +
|
||||
"CSeq: 1 REGISTER\r\n\r\n")
|
||||
|
||||
result = validator.Validate(noCallIDMsg)
|
||||
if result.Valid {
|
||||
t.Error("Message missing Call-ID should be invalid")
|
||||
}
|
||||
|
||||
hasCallIDViolation := false
|
||||
for _, v := range result.Violations {
|
||||
if v.Rule == "missing_call_id" {
|
||||
hasCallIDViolation = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasCallIDViolation {
|
||||
t.Error("Expected missing_call_id violation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidationCompactHeaders(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
validator := NewSIPValidator(logger, DefaultValidationConfig())
|
||||
|
||||
// Valid message using compact header forms
|
||||
compactMsg := []byte("INVITE sip:bob@example.com SIP/2.0\r\n" +
|
||||
"v: SIP/2.0/UDP 192.168.1.1:5060;branch=z9hG4bK776\r\n" + // v = Via
|
||||
"f: <sip:alice@example.com>;tag=1234\r\n" + // f = From
|
||||
"t: <sip:bob@example.com>\r\n" + // t = To
|
||||
"i: a84b4c76e66710@192.168.1.1\r\n" + // i = Call-ID
|
||||
"CSeq: 1 INVITE\r\n" +
|
||||
"Max-Forwards: 70\r\n" +
|
||||
"l: 0\r\n\r\n") // l = Content-Length
|
||||
|
||||
result := validator.Validate(compactMsg)
|
||||
if !result.Valid {
|
||||
t.Errorf("Valid message with compact headers should pass validation: %+v", result.Violations)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidationOversizedMessage(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
config := DefaultValidationConfig()
|
||||
config.MaxMessageSize = 500 // Small limit for testing
|
||||
validator := NewSIPValidator(logger, config)
|
||||
|
||||
// Create message larger than limit
|
||||
largeBody := strings.Repeat("X", 600)
|
||||
largeMsg := []byte("INVITE sip:bob@example.com SIP/2.0\r\n" +
|
||||
"Via: SIP/2.0/UDP 192.168.1.1:5060;branch=z9hG4bK123\r\n" +
|
||||
"From: <sip:alice@example.com>;tag=abc\r\n" +
|
||||
"To: <sip:bob@example.com>\r\n" +
|
||||
"Call-ID: 123456@192.168.1.1\r\n" +
|
||||
"CSeq: 1 INVITE\r\n" +
|
||||
"Max-Forwards: 70\r\n" +
|
||||
"Content-Length: 600\r\n\r\n" + largeBody)
|
||||
|
||||
result := validator.Validate(largeMsg)
|
||||
if result.Valid {
|
||||
t.Error("Oversized message should be invalid")
|
||||
}
|
||||
|
||||
hasOversizedViolation := false
|
||||
for _, v := range result.Violations {
|
||||
if v.Rule == "oversized_message" {
|
||||
hasOversizedViolation = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasOversizedViolation {
|
||||
t.Error("Expected oversized_message violation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidationCSeqRange(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
validator := NewSIPValidator(logger, DefaultValidationConfig())
|
||||
|
||||
// Valid CSeq (within range)
|
||||
validCSeqMsg := []byte("REGISTER sip:example.com SIP/2.0\r\n" +
|
||||
"Via: SIP/2.0/UDP 192.168.1.1:5060;branch=z9hG4bK123\r\n" +
|
||||
"From: <sip:user@example.com>;tag=abc\r\n" +
|
||||
"To: <sip:user@example.com>\r\n" +
|
||||
"Call-ID: 123456@192.168.1.1\r\n" +
|
||||
"CSeq: 100 REGISTER\r\n" +
|
||||
"Max-Forwards: 70\r\n\r\n")
|
||||
|
||||
result := validator.Validate(validCSeqMsg)
|
||||
hasCSeqViolation := false
|
||||
for _, v := range result.Violations {
|
||||
if v.Rule == "cseq_out_of_range" {
|
||||
hasCSeqViolation = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasCSeqViolation {
|
||||
t.Error("Valid CSeq should not trigger cseq_out_of_range")
|
||||
}
|
||||
|
||||
// Invalid CSeq (out of range - negative not possible as string, but
|
||||
// we can test with very large number exceeding int32)
|
||||
// Note: Testing the actual boundary is tricky since ParseInt may handle it
|
||||
}
|
||||
|
||||
func TestValidationContentLengthMismatch(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
validator := NewSIPValidator(logger, DefaultValidationConfig())
|
||||
|
||||
// Correct Content-Length
|
||||
correctMsg := []byte("INVITE sip:bob@example.com SIP/2.0\r\n" +
|
||||
"Via: SIP/2.0/UDP 192.168.1.1:5060;branch=z9hG4bK123\r\n" +
|
||||
"From: <sip:alice@example.com>;tag=abc\r\n" +
|
||||
"To: <sip:bob@example.com>\r\n" +
|
||||
"Call-ID: 123456@192.168.1.1\r\n" +
|
||||
"CSeq: 1 INVITE\r\n" +
|
||||
"Max-Forwards: 70\r\n" +
|
||||
"Content-Length: 4\r\n\r\ntest")
|
||||
|
||||
result := validator.Validate(correctMsg)
|
||||
hasMismatch := false
|
||||
for _, v := range result.Violations {
|
||||
if v.Rule == "content_length_mismatch" {
|
||||
hasMismatch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasMismatch {
|
||||
t.Error("Correct Content-Length should not trigger mismatch")
|
||||
}
|
||||
|
||||
// Wrong Content-Length (claims 100 but body is 4)
|
||||
wrongMsg := []byte("INVITE sip:bob@example.com SIP/2.0\r\n" +
|
||||
"Via: SIP/2.0/UDP 192.168.1.1:5060;branch=z9hG4bK123\r\n" +
|
||||
"From: <sip:alice@example.com>;tag=abc\r\n" +
|
||||
"To: <sip:bob@example.com>\r\n" +
|
||||
"Call-ID: 123456@192.168.1.1\r\n" +
|
||||
"CSeq: 1 INVITE\r\n" +
|
||||
"Max-Forwards: 70\r\n" +
|
||||
"Content-Length: 100\r\n\r\ntest")
|
||||
|
||||
result = validator.Validate(wrongMsg)
|
||||
hasMismatch = false
|
||||
for _, v := range result.Violations {
|
||||
if v.Rule == "content_length_mismatch" {
|
||||
hasMismatch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasMismatch {
|
||||
t.Error("Wrong Content-Length should trigger mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidationViaBranch(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
config := DefaultValidationConfig()
|
||||
config.Mode = ValidationModeStrict // Via branch check only in strict mode
|
||||
validator := NewSIPValidator(logger, config)
|
||||
|
||||
// Valid Via branch (RFC 3261 compliant - starts with z9hG4bK)
|
||||
validBranchMsg := []byte("REGISTER sip:example.com SIP/2.0\r\n" +
|
||||
"Via: SIP/2.0/UDP 192.168.1.1:5060;branch=z9hG4bKnashds7\r\n" +
|
||||
"From: <sip:user@example.com>;tag=abc\r\n" +
|
||||
"To: <sip:user@example.com>\r\n" +
|
||||
"Call-ID: 123456@192.168.1.1\r\n" +
|
||||
"CSeq: 1 REGISTER\r\n" +
|
||||
"Max-Forwards: 70\r\n\r\n")
|
||||
|
||||
result := validator.Validate(validBranchMsg)
|
||||
hasViaBranchViolation := false
|
||||
for _, v := range result.Violations {
|
||||
if v.Rule == "via_invalid_branch" {
|
||||
hasViaBranchViolation = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasViaBranchViolation {
|
||||
t.Error("Valid Via branch should not trigger violation")
|
||||
}
|
||||
|
||||
// Invalid Via branch (doesn't start with z9hG4bK)
|
||||
invalidBranchMsg := []byte("REGISTER sip:example.com SIP/2.0\r\n" +
|
||||
"Via: SIP/2.0/UDP 192.168.1.1:5060;branch=oldbranch123\r\n" +
|
||||
"From: <sip:user@example.com>;tag=abc\r\n" +
|
||||
"To: <sip:user@example.com>\r\n" +
|
||||
"Call-ID: 123456@192.168.1.1\r\n" +
|
||||
"CSeq: 1 REGISTER\r\n" +
|
||||
"Max-Forwards: 70\r\n\r\n")
|
||||
|
||||
result = validator.Validate(invalidBranchMsg)
|
||||
hasViaBranchViolation = false
|
||||
for _, v := range result.Violations {
|
||||
if v.Rule == "via_invalid_branch" {
|
||||
hasViaBranchViolation = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasViaBranchViolation {
|
||||
t.Error("Invalid Via branch should trigger violation in strict mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidationDisabledRules(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
config := DefaultValidationConfig()
|
||||
config.DisabledRules = []string{"null_bytes", "missing_via"}
|
||||
validator := NewSIPValidator(logger, config)
|
||||
|
||||
// Message with null byte should NOT trigger ban when rule disabled
|
||||
nullMsg := []byte("REGISTER sip:example.com SIP/2.0\r\n" +
|
||||
"From: <sip:user\x00@example.com>;tag=abc\r\n" + // Null byte + missing Via
|
||||
"To: <sip:user@example.com>\r\n" +
|
||||
"Call-ID: 123456@192.168.1.1\r\n" +
|
||||
"CSeq: 1 REGISTER\r\n\r\n")
|
||||
|
||||
result := validator.Validate(nullMsg)
|
||||
|
||||
// Should not have null_bytes violation (disabled)
|
||||
hasNullViolation := false
|
||||
for _, v := range result.Violations {
|
||||
if v.Rule == "null_bytes" {
|
||||
hasNullViolation = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasNullViolation {
|
||||
t.Error("Disabled null_bytes rule should not trigger violation")
|
||||
}
|
||||
|
||||
// Should not have missing_via violation (disabled)
|
||||
hasViaViolation := false
|
||||
for _, v := range result.Violations {
|
||||
if v.Rule == "missing_via" {
|
||||
hasViaViolation = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasViaViolation {
|
||||
t.Error("Disabled missing_via rule should not trigger violation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidationPermissiveMode(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
config := DefaultValidationConfig()
|
||||
config.Mode = ValidationModePermissive
|
||||
config.BanOnNullBytes = false // Disable ban in permissive
|
||||
config.BanOnBinaryInjection = false
|
||||
validator := NewSIPValidator(logger, config)
|
||||
|
||||
// Message missing Via - should log but not ban in permissive
|
||||
noViaMsg := []byte("REGISTER sip:example.com SIP/2.0\r\n" +
|
||||
"From: <sip:user@example.com>;tag=abc\r\n" +
|
||||
"To: <sip:user@example.com>\r\n" +
|
||||
"Call-ID: 123456@192.168.1.1\r\n" +
|
||||
"CSeq: 1 REGISTER\r\n\r\n")
|
||||
|
||||
result := validator.Validate(noViaMsg)
|
||||
if result.ShouldBan {
|
||||
t.Error("Permissive mode should not ban for missing headers")
|
||||
}
|
||||
if result.Valid {
|
||||
t.Error("Message should still be marked invalid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidationParanoidMode(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
config := DefaultValidationConfig()
|
||||
config.Mode = ValidationModeParanoid
|
||||
validator := NewSIPValidator(logger, config)
|
||||
|
||||
// Valid message should pass even in paranoid mode
|
||||
validMsg := []byte("REGISTER sip:example.com SIP/2.0\r\n" +
|
||||
"Via: SIP/2.0/UDP 192.168.1.1:5060;branch=z9hG4bK123\r\n" +
|
||||
"From: <sip:user@example.com>;tag=abc\r\n" +
|
||||
"To: <sip:user@example.com>\r\n" +
|
||||
"Call-ID: 123456@192.168.1.1\r\n" +
|
||||
"CSeq: 1 REGISTER\r\n" +
|
||||
"Max-Forwards: 70\r\n" +
|
||||
"Content-Length: 0\r\n\r\n")
|
||||
|
||||
result := validator.Validate(validMsg)
|
||||
if !result.Valid {
|
||||
t.Errorf("Valid message should pass paranoid validation: %+v", result.Violations)
|
||||
}
|
||||
|
||||
// Message with suspicious SQL-like pattern
|
||||
sqlMsg := []byte("REGISTER sip:example.com SIP/2.0\r\n" +
|
||||
"Via: SIP/2.0/UDP 192.168.1.1:5060;branch=z9hG4bK123\r\n" +
|
||||
"From: <sip:' OR 1=1--@example.com>;tag=abc\r\n" +
|
||||
"To: <sip:user@example.com>\r\n" +
|
||||
"Call-ID: 123456@192.168.1.1\r\n" +
|
||||
"CSeq: 1 REGISTER\r\n" +
|
||||
"Max-Forwards: 70\r\n\r\n")
|
||||
|
||||
result = validator.Validate(sqlMsg)
|
||||
hasSQLViolation := false
|
||||
for _, v := range result.Violations {
|
||||
if v.Rule == "suspicious_sql_injection" {
|
||||
hasSQLViolation = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasSQLViolation {
|
||||
t.Error("Paranoid mode should detect SQL injection patterns")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidationRequestURI(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
config := DefaultValidationConfig()
|
||||
config.Mode = ValidationModeStrict
|
||||
validator := NewSIPValidator(logger, config)
|
||||
|
||||
// Valid sip: URI
|
||||
validSipMsg := []byte("REGISTER sip:example.com SIP/2.0\r\n" +
|
||||
"Via: SIP/2.0/UDP 192.168.1.1:5060;branch=z9hG4bK123\r\n" +
|
||||
"From: <sip:user@example.com>;tag=abc\r\n" +
|
||||
"To: <sip:user@example.com>\r\n" +
|
||||
"Call-ID: 123456@192.168.1.1\r\n" +
|
||||
"CSeq: 1 REGISTER\r\n" +
|
||||
"Max-Forwards: 70\r\n\r\n")
|
||||
|
||||
result := validator.Validate(validSipMsg)
|
||||
hasURIViolation := false
|
||||
for _, v := range result.Violations {
|
||||
if v.Rule == "invalid_request_uri" {
|
||||
hasURIViolation = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasURIViolation {
|
||||
t.Error("Valid sip: URI should not trigger violation")
|
||||
}
|
||||
|
||||
// Valid tel: URI
|
||||
validTelMsg := []byte("INVITE tel:+1-555-1234 SIP/2.0\r\n" +
|
||||
"Via: SIP/2.0/UDP 192.168.1.1:5060;branch=z9hG4bK123\r\n" +
|
||||
"From: <sip:user@example.com>;tag=abc\r\n" +
|
||||
"To: <tel:+1-555-1234>\r\n" +
|
||||
"Call-ID: 123456@192.168.1.1\r\n" +
|
||||
"CSeq: 1 INVITE\r\n" +
|
||||
"Max-Forwards: 70\r\n\r\n")
|
||||
|
||||
result = validator.Validate(validTelMsg)
|
||||
hasURIViolation = false
|
||||
for _, v := range result.Violations {
|
||||
if v.Rule == "invalid_request_uri" {
|
||||
hasURIViolation = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasURIViolation {
|
||||
t.Error("Valid tel: URI should not trigger violation")
|
||||
}
|
||||
|
||||
// Invalid URI scheme
|
||||
invalidURIMsg := []byte("INVITE http://example.com/attack SIP/2.0\r\n" +
|
||||
"Via: SIP/2.0/UDP 192.168.1.1:5060;branch=z9hG4bK123\r\n" +
|
||||
"From: <sip:user@example.com>;tag=abc\r\n" +
|
||||
"To: <sip:user@example.com>\r\n" +
|
||||
"Call-ID: 123456@192.168.1.1\r\n" +
|
||||
"CSeq: 1 INVITE\r\n" +
|
||||
"Max-Forwards: 70\r\n\r\n")
|
||||
|
||||
result = validator.Validate(invalidURIMsg)
|
||||
hasURIViolation = false
|
||||
for _, v := range result.Violations {
|
||||
if v.Rule == "invalid_request_uri" {
|
||||
hasURIViolation = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !hasURIViolation {
|
||||
t.Error("Invalid URI scheme should trigger violation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatorGlobalInstance(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
|
||||
// Get global validator
|
||||
v1 := GetValidator(logger)
|
||||
v2 := GetValidator(logger)
|
||||
|
||||
if v1 != v2 {
|
||||
t.Error("GetValidator should return the same instance")
|
||||
}
|
||||
|
||||
// Update config
|
||||
config := DefaultValidationConfig()
|
||||
config.MaxMessageSize = 12345
|
||||
SetValidationConfig(config)
|
||||
|
||||
// Should reflect updated config
|
||||
v3 := GetValidator(logger)
|
||||
stats := v3.GetStats()
|
||||
if stats["config_enabled"] != true {
|
||||
t.Error("Config should be enabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatorStats(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
config := DefaultValidationConfig()
|
||||
validator := NewSIPValidator(logger, config)
|
||||
|
||||
// Validate some messages
|
||||
validMsg := []byte("REGISTER sip:example.com SIP/2.0\r\n" +
|
||||
"Via: SIP/2.0/UDP 192.168.1.1:5060;branch=z9hG4bK123\r\n" +
|
||||
"From: <sip:user@example.com>;tag=abc\r\n" +
|
||||
"To: <sip:user@example.com>\r\n" +
|
||||
"Call-ID: 123456@192.168.1.1\r\n" +
|
||||
"CSeq: 1 REGISTER\r\n" +
|
||||
"Max-Forwards: 70\r\n\r\n")
|
||||
|
||||
invalidMsg := []byte("REGISTER sip:example.com SIP/2.0\r\n" +
|
||||
"From: <sip:user@example.com>;tag=abc\r\n" + // Missing Via
|
||||
"CSeq: 1 REGISTER\r\n\r\n")
|
||||
|
||||
validator.Validate(validMsg)
|
||||
validator.Validate(validMsg)
|
||||
validator.Validate(invalidMsg)
|
||||
|
||||
stats := validator.GetStats()
|
||||
|
||||
if stats["total_validated"].(int64) != 3 {
|
||||
t.Errorf("Expected 3 total validated, got %v", stats["total_validated"])
|
||||
}
|
||||
if stats["total_valid"].(int64) != 2 {
|
||||
t.Errorf("Expected 2 total valid, got %v", stats["total_valid"])
|
||||
}
|
||||
if stats["total_invalid"].(int64) != 1 {
|
||||
t.Errorf("Expected 1 total invalid, got %v", stats["total_invalid"])
|
||||
}
|
||||
|
||||
violations := stats["violations_by_rule"].(map[string]int64)
|
||||
if violations["missing_via"] != 1 {
|
||||
t.Errorf("Expected 1 missing_via violation, got %d", violations["missing_via"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidationSIPResponse(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
config := DefaultValidationConfig()
|
||||
config.Mode = ValidationModeStrict
|
||||
validator := NewSIPValidator(logger, config)
|
||||
|
||||
// SIP response (starts with SIP/2.0)
|
||||
responseMsg := []byte("SIP/2.0 200 OK\r\n" +
|
||||
"Via: SIP/2.0/UDP 192.168.1.1:5060;branch=z9hG4bK123;received=10.0.0.1\r\n" +
|
||||
"From: <sip:user@example.com>;tag=abc\r\n" +
|
||||
"To: <sip:user@example.com>;tag=xyz\r\n" +
|
||||
"Call-ID: 123456@192.168.1.1\r\n" +
|
||||
"CSeq: 1 REGISTER\r\n" +
|
||||
"Contact: <sip:user@10.0.0.1:5060>\r\n" +
|
||||
"Content-Length: 0\r\n\r\n")
|
||||
|
||||
result := validator.Validate(responseMsg)
|
||||
|
||||
// Responses should not trigger invalid_request_uri
|
||||
hasURIViolation := false
|
||||
for _, v := range result.Violations {
|
||||
if v.Rule == "invalid_request_uri" {
|
||||
hasURIViolation = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if hasURIViolation {
|
||||
t.Error("SIP responses should not be checked for Request-URI")
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user