caddy-sip-guardian/ratelimit.go
Ryan Malloy c73fa9d3d1 Add extension enumeration detection and comprehensive SIP protection
Major features:
- Extension enumeration detection with 3 detection algorithms:
  - Max unique extensions threshold (default: 20 in 5 min)
  - Sequential pattern detection (e.g., 100,101,102...)
  - Rapid-fire detection (many extensions in short window)
- Prometheus metrics for all SIP Guardian operations
- SQLite persistent storage for bans and attack history
- Webhook notifications for ban/unban/suspicious events
- GeoIP-based country blocking with continent shortcuts
- Per-method rate limiting with token bucket algorithm

Bug fixes:
- Fix whitelist count always reporting zero in stats
- Fix whitelisted connections metric never incrementing
- Fix Caddyfile config not being applied to shared guardian

New files:
- enumeration.go: Extension enumeration detector
- enumeration_test.go: 14 comprehensive unit tests
- metrics.go: Prometheus metrics handler
- storage.go: SQLite persistence layer
- webhooks.go: Webhook notification system
- geoip.go: MaxMind GeoIP integration
- ratelimit.go: Per-method rate limiting

Testing:
- sandbox/ contains complete Docker Compose test environment
- All 14 enumeration tests pass
2025-12-07 15:22:28 -07:00

361 lines
8.9 KiB
Go

package sipguardian
import (
"regexp"
"strings"
"sync"
"time"
"go.uber.org/zap"
)
// SIPMethod is the method token of a SIP request (for example "INVITE").
// It is a plain string type, so tokens other than the constants below can
// also be represented.
type SIPMethod string

// Known SIP request methods. Each constant's value is the method token
// exactly as it appears on the wire.
const (
	// Core methods from RFC 3261.
	MethodREGISTER SIPMethod = "REGISTER"
	MethodINVITE   SIPMethod = "INVITE"
	MethodOPTIONS  SIPMethod = "OPTIONS"
	MethodACK      SIPMethod = "ACK"
	MethodBYE      SIPMethod = "BYE"
	MethodCANCEL   SIPMethod = "CANCEL"

	// Extension methods defined in later RFCs.
	MethodINFO      SIPMethod = "INFO"
	MethodNOTIFY    SIPMethod = "NOTIFY"
	MethodSUBSCRIBE SIPMethod = "SUBSCRIBE"
	MethodMESSAGE   SIPMethod = "MESSAGE"
	MethodUPDATE    SIPMethod = "UPDATE"
	MethodPRACK     SIPMethod = "PRACK"
	MethodREFER     SIPMethod = "REFER"
	MethodPUBLISH   SIPMethod = "PUBLISH"
)
// MethodRateLimit configures the token bucket used for one SIP method.
//
// MaxRequests and Window together define the sustained refill rate
// (MaxRequests tokens per Window). BurstSize is the bucket's capacity, i.e.
// how many requests can be admitted back-to-back after an idle period.
type MethodRateLimit struct {
	Method      SIPMethod     `json:"method"`               // method the limit applies to
	MaxRequests int           `json:"max_requests"`         // requests replenished per Window
	Window      time.Duration `json:"window"`               // period over which MaxRequests refill
	BurstSize   int           `json:"burst_size,omitempty"` // bucket capacity; 0 lets the limiter derive one
}
// RateLimiter enforces token-bucket limits keyed by client IP and SIP method.
//
// mu guards the limits map and the buckets map itself. The token state
// inside each *methodBuckets is guarded by that value's own mutex.
type RateLimiter struct {
	limits  map[SIPMethod]*MethodRateLimit // per-method overrides installed via SetLimit
	buckets map[string]*methodBuckets      // bucket sets keyed by client IP
	logger  *zap.Logger
	mu      sync.RWMutex

	// Fallback limit that GetLimit hands out for methods with no override.
	defaultMaxRequests int
	defaultWindow      time.Duration
}
// methodBuckets is the set of token buckets for a single client IP, one per
// SIP method seen from that IP. mu guards the fields below.
type methodBuckets struct {
	methods   map[SIPMethod]*tokenBucket // populated lazily by Allow
	lastReset time.Time                  // compared against Cleanup's cutoff to decide eviction
	mu        sync.Mutex
}
// tokenBucket is the mutable state of one token bucket. Tokens accrue
// continuously at refillRate per second, up to maxTokens, and each admitted
// request spends one token.
type tokenBucket struct {
	tokens     float64   // tokens currently available (may be fractional)
	lastUpdate time.Time // instant tokens was last brought up to date
	maxTokens  float64   // bucket capacity
	refillRate float64   // tokens added per second
}
// globalRateLimiter is the process-wide limiter, created lazily by
// GetRateLimiter.
var globalRateLimiter *RateLimiter

// rateLimiterMu serializes GetRateLimiter's lazy creation of
// globalRateLimiter.
var rateLimiterMu sync.Mutex
// DefaultMethodLimits holds the per-method limits that GetRateLimiter installs
// into the global limiter when it first creates it. Methods absent from this
// map are governed by the limiter's generic fallback instead (see GetLimit).
//
// The pointers are passed to SetLimit as-is, so the global limiter and this
// map share the same MethodRateLimit values.
var DefaultMethodLimits = map[SIPMethod]*MethodRateLimit{
	MethodREGISTER:  {Method: MethodREGISTER, MaxRequests: 10, Window: time.Minute, BurstSize: 3},
	MethodINVITE:    {Method: MethodINVITE, MaxRequests: 30, Window: time.Minute, BurstSize: 5},
	MethodOPTIONS:   {Method: MethodOPTIONS, MaxRequests: 60, Window: time.Minute, BurstSize: 10},
	MethodSUBSCRIBE: {Method: MethodSUBSCRIBE, MaxRequests: 20, Window: time.Minute, BurstSize: 5},
	MethodMESSAGE:   {Method: MethodMESSAGE, MaxRequests: 100, Window: time.Minute, BurstSize: 20},
}
// NewRateLimiter returns a RateLimiter with no per-method overrides and no
// tracked IPs. Until SetLimit is called, every method falls back to 100
// requests per minute.
func NewRateLimiter(logger *zap.Logger) *RateLimiter {
	rl := new(RateLimiter)
	rl.logger = logger
	rl.limits = map[SIPMethod]*MethodRateLimit{}
	rl.buckets = map[string]*methodBuckets{}
	rl.defaultMaxRequests = 100
	rl.defaultWindow = time.Minute
	return rl
}
// GetRateLimiter returns the shared, process-wide RateLimiter. The first call
// creates it with the given logger and seeds it with DefaultMethodLimits.
// Later calls return that same instance, and their logger argument is
// ignored.
func GetRateLimiter(logger *zap.Logger) *RateLimiter {
	rateLimiterMu.Lock()
	defer rateLimiterMu.Unlock()

	if globalRateLimiter != nil {
		return globalRateLimiter
	}

	globalRateLimiter = NewRateLimiter(logger)
	for m, l := range DefaultMethodLimits {
		globalRateLimiter.SetLimit(m, l)
	}
	return globalRateLimiter
}
// SetLimit installs limit as the rate limit for method, replacing any earlier
// override.
//
// A nil limit removes the override, so the method falls back to the
// limiter's default (see GetLimit). Storing nil would otherwise make GetLimit
// return nil, and Allow and GetStats would then dereference a nil pointer.
//
// The pointer is stored, not copied. Token buckets that Allow has already
// created for method keep the parameters they were built with.
func (rl *RateLimiter) SetLimit(method SIPMethod, limit *MethodRateLimit) {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	if limit == nil {
		delete(rl.limits, method)
		return
	}
	rl.limits[method] = limit
}
// GetLimit reports the limit that applies to method. If SetLimit installed
// an override for it, that override is returned. Otherwise it returns a newly
// allocated limit built from the limiter's defaults, with no explicit
// BurstSize.
func (rl *RateLimiter) GetLimit(method SIPMethod) *MethodRateLimit {
	rl.mu.RLock()
	defer rl.mu.RUnlock()

	limit, found := rl.limits[method]
	if !found {
		limit = &MethodRateLimit{
			Method:      method,
			MaxRequests: rl.defaultMaxRequests,
			Window:      rl.defaultWindow,
		}
	}
	return limit
}
// Allow reports whether a request for method from ip may proceed under the
// configured rate limits. It returns (true, "") when a token was available
// and has been consumed. It returns (false, "rate_limit_<METHOD>") when the
// request should be rejected.
//
// Every (ip, method) pair has its own token bucket, created on first use from
// the limit returned by GetLimit:
//   - Capacity is BurstSize, or MaxRequests/5 (at least 1) when BurstSize is
//     not positive.
//   - Tokens refill continuously at MaxRequests per Window. A non-positive
//     Window falls back to the limiter's default window. Dividing by a zero
//     Window gives an infinite or NaN refill rate, which either disables
//     limiting or blocks the method forever.
//
// Each call also refreshes the IP's lastReset timestamp. Cleanup evicts IPs
// whose timestamp is older than its cutoff, so with this refresh only idle
// IPs are evicted. Without it, an IP would be evicted a fixed time after its
// first request even while still sending, handing an active client a fresh,
// full bucket.
func (rl *RateLimiter) Allow(ip string, method SIPMethod) (bool, string) {
	rl.mu.Lock()
	bucket, exists := rl.buckets[ip]
	if !exists {
		bucket = &methodBuckets{
			methods:   make(map[SIPMethod]*tokenBucket),
			lastReset: time.Now(),
		}
		rl.buckets[ip] = bucket
	}
	rl.mu.Unlock()

	limit := rl.GetLimit(method)

	bucket.mu.Lock()
	defer bucket.mu.Unlock()

	// Read the clock under the bucket lock so refills for this IP are applied
	// in order.
	now := time.Now()
	bucket.lastReset = now // mark this IP as active for Cleanup

	tb, exists := bucket.methods[method]
	if !exists {
		burstSize := limit.BurstSize
		if burstSize <= 0 {
			// Default burst is 20% of the per-window maximum, but at least 1.
			burstSize = limit.MaxRequests / 5
			if burstSize < 1 {
				burstSize = 1
			}
		}
		window := limit.Window
		if window <= 0 {
			window = rl.defaultWindow
		}
		tb = &tokenBucket{
			tokens:     float64(burstSize),
			lastUpdate: now,
			maxTokens:  float64(burstSize),
			refillRate: float64(limit.MaxRequests) / window.Seconds(),
		}
		bucket.methods[method] = tb
	}

	// Refill for the time elapsed since the previous request, capped at the
	// bucket's capacity.
	tb.tokens += now.Sub(tb.lastUpdate).Seconds() * tb.refillRate
	if tb.tokens > tb.maxTokens {
		tb.tokens = tb.maxTokens
	}
	tb.lastUpdate = now

	if tb.tokens < 1.0 {
		return false, "rate_limit_" + string(method)
	}
	tb.tokens -= 1.0
	return true, ""
}
// Cleanup drops every per-IP bucket set whose lastReset timestamp is more
// than 10 minutes in the past. It holds the limiter's write lock for the
// whole scan, so Allow calls that need the map wait until it finishes.
func (rl *RateLimiter) Cleanup() {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	cutoff := time.Now().Add(-10 * time.Minute)
	for ip, b := range rl.buckets {
		b.mu.Lock()
		stale := b.lastReset.Before(cutoff)
		b.mu.Unlock()

		// Deleting the current key while ranging over a map is safe in Go.
		if stale {
			delete(rl.buckets, ip)
		}
	}
}
// GetStats returns a JSON-friendly snapshot of the limiter. It contains the
// number of IPs that currently have bucket sets ("tracked_ips") and the
// installed per-method overrides ("limits"). Methods that only use the
// generic fallback are not listed.
func (rl *RateLimiter) GetStats() map[string]interface{} {
	rl.mu.RLock()
	defer rl.mu.RUnlock()

	configured := make(map[string]interface{}, len(rl.limits))
	for m, l := range rl.limits {
		entry := map[string]interface{}{}
		entry["max_requests"] = l.MaxRequests
		entry["window"] = l.Window.String()
		entry["burst_size"] = l.BurstSize
		configured[string(m)] = entry
	}

	stats := make(map[string]interface{}, 2)
	stats["tracked_ips"] = len(rl.buckets)
	stats["limits"] = configured
	return stats
}
// sipMethodPattern matches "METHOD sip:" at the very start of a message and
// captures the method token. Only the fourteen methods listed are
// recognized. Because the pattern has no (?m) flag, ^ anchors to the start of
// the whole input.
var sipMethodPattern = regexp.MustCompile(`^(REGISTER|INVITE|OPTIONS|ACK|BYE|CANCEL|INFO|NOTIFY|SUBSCRIBE|MESSAGE|UPDATE|PRACK|REFER|PUBLISH)\s+sip:`)

// sipTargetExtPattern captures the user part of the Request-URI
// ("METHOD sip:user@host"). It only matches REGISTER, INVITE, OPTIONS,
// MESSAGE and SUBSCRIBE requests.
var sipTargetExtPattern = regexp.MustCompile(`^(?:REGISTER|INVITE|OPTIONS|MESSAGE|SUBSCRIBE)\s+sip:([^@\s>]+)@`)
// ExtractSIPMethod returns the request method named at the start of a SIP
// message. It returns "" unless the message begins with one of the methods
// known to sipMethodPattern, followed by whitespace and a sip: URI.
func ExtractSIPMethod(data []byte) SIPMethod {
	groups := sipMethodPattern.FindSubmatch(data)
	if len(groups) < 2 {
		return ""
	}
	return SIPMethod(groups[1])
}
// ExtractTargetExtension returns the user part (the "extension") that a SIP
// request targets, or "" if none is found.
//
// It consults the Request-URI first, which applies only to the methods
// matched by sipTargetExtPattern. If that yields nothing usable, it falls
// back to the user part of the To header. A candidate is accepted only if it
// is 1-10 bytes long and contains no '.', which filters out domain-like
// values.
func ExtractTargetExtension(data []byte) string {
	usable := func(candidate string) bool {
		return candidate != "" && len(candidate) <= 10 && !strings.Contains(candidate, ".")
	}

	// Request-URI form: REGISTER sip:1001@domain.com
	if m := sipTargetExtPattern.FindStringSubmatch(string(data)); len(m) > 1 && usable(m[1]) {
		return m[1]
	}

	if user, _ := ParseToHeader(data); usable(user) {
		return user
	}
	return ""
}
// sipHeaderURIPattern captures the user and host of the first
// "sip:user@host" URI in a From/To header value. It is compiled once at
// package load rather than on every call.
var sipHeaderURIPattern = regexp.MustCompile(`sip:([^@]+)@([^>;\s]+)`)

// scanSIPHeaders walks the header section of a SIP message. For each header
// line it calls visit with the header name (trimmed and lowercased) and the
// value (trimmed of surrounding whitespace).
//
// Scanning stops when visit returns true, or at the first empty line after
// the start line. That empty line separates the headers from the body
// (RFC 3261 §7), so body text is never mistaken for a header.
//
// Leniency rules:
//   - Lines may end in CRLF or a bare LF.
//   - Empty lines before the start line are skipped.
//   - Whitespace between a header name and its colon is accepted.
//   - Folded continuation lines (leading space or tab) are ignored.
func scanSIPHeaders(data []byte, visit func(name, value string) bool) {
	rest := string(data)
	started := false
	for rest != "" {
		var line string
		line, rest, _ = strings.Cut(rest, "\n")
		line = strings.TrimSuffix(line, "\r")
		if line == "" {
			if started {
				return // end of headers; the body follows
			}
			continue // CRLFs preceding the start line
		}
		started = true
		if line[0] == ' ' || line[0] == '\t' {
			continue // continuation of a folded header
		}
		name, value, ok := strings.Cut(line, ":")
		if !ok {
			continue
		}
		if visit(strings.ToLower(strings.TrimSpace(name)), strings.TrimSpace(value)) {
			return
		}
	}
}

// firstSIPHeaderValue returns the trimmed value of the first header whose
// lowercased name equals one of names, or "" if there is none.
func firstSIPHeaderValue(data []byte, names ...string) string {
	var found string
	scanSIPHeaders(data, func(name, value string) bool {
		for _, want := range names {
			if name == want {
				found = value
				return true
			}
		}
		return false
	})
	return found
}

// parseSIPAddressHeader returns the user and host of the sip: URI in the
// first header named long or compact that contains one. A matching header
// without a sip:user@host URI is skipped, and the scan continues.
func parseSIPAddressHeader(data []byte, long, compact string) (user, domain string) {
	scanSIPHeaders(data, func(name, value string) bool {
		if name != long && name != compact {
			return false
		}
		m := sipHeaderURIPattern.FindStringSubmatch(value)
		if len(m) < 3 {
			return false
		}
		user, domain = m[1], m[2]
		return true
	})
	return user, domain
}

// ParseUserAgent returns the value of the User-Agent header, trimmed of
// surrounding whitespace, or "" if the header is absent.
func ParseUserAgent(data []byte) string {
	return firstSIPHeaderValue(data, "user-agent")
}

// ParseFromHeader returns the user and host of the sip: URI in the From
// header, in either the long form "From" or the compact form "f". It returns
// two empty strings if no such URI is present.
func ParseFromHeader(data []byte) (user, domain string) {
	return parseSIPAddressHeader(data, "from", "f")
}

// ParseToHeader returns the user and host of the sip: URI in the To header,
// in either the long form "To" or the compact form "t". It returns two empty
// strings if no such URI is present.
func ParseToHeader(data []byte) (user, domain string) {
	return parseSIPAddressHeader(data, "to", "t")
}

// ParseCallID returns the Call-ID header value, in either the long form
// "Call-ID" or the compact form "i", trimmed of surrounding whitespace. It
// returns "" if the header is absent.
func ParseCallID(data []byte) string {
	return firstSIPHeaderValue(data, "call-id", "i")
}