package sipguardian

import (
	"regexp"
	"strings"
	"sync"
	"time"

	"go.uber.org/zap"
)

// SIPMethod represents a SIP request method.
type SIPMethod string

// SIP request methods recognised by the rate limiter and the message parsers.
const (
	MethodREGISTER  SIPMethod = "REGISTER"
	MethodINVITE    SIPMethod = "INVITE"
	MethodOPTIONS   SIPMethod = "OPTIONS"
	MethodACK       SIPMethod = "ACK"
	MethodBYE       SIPMethod = "BYE"
	MethodCANCEL    SIPMethod = "CANCEL"
	MethodINFO      SIPMethod = "INFO"
	MethodNOTIFY    SIPMethod = "NOTIFY"
	MethodSUBSCRIBE SIPMethod = "SUBSCRIBE"
	MethodMESSAGE   SIPMethod = "MESSAGE"
	MethodUPDATE    SIPMethod = "UPDATE"
	MethodPRACK     SIPMethod = "PRACK"
	MethodREFER     SIPMethod = "REFER"
	MethodPUBLISH   SIPMethod = "PUBLISH"
)

// MethodRateLimit defines the rate limit applied to one SIP method.
type MethodRateLimit struct {
	// Method is the SIP method this limit applies to.
	Method SIPMethod `json:"method"`

	// MaxRequests is the number of requests allowed per Window; together with
	// Window it determines the token refill rate.
	MaxRequests int `json:"max_requests"`

	// Window is the time window for MaxRequests. A non-positive Window is
	// replaced by the limiter's default window when a bucket is created.
	Window time.Duration `json:"window"`

	// BurstSize is the token bucket capacity: how many requests may be made
	// back-to-back. A non-positive value means 20% of MaxRequests (minimum 1).
	BurstSize int `json:"burst_size,omitempty"`
}

// RateLimiter provides per-IP, per-method rate limiting using token buckets.
// All methods are safe for concurrent use.
type RateLimiter struct {
	limits  map[SIPMethod]*MethodRateLimit
	buckets map[string]*methodBuckets
	logger  *zap.Logger
	mu      sync.RWMutex // guards limits and buckets

	// Default limits (used when no specific limit is configured).
	defaultMaxRequests int
	defaultWindow      time.Duration
}

// methodBuckets tracks the token buckets of a single IP, one per method.
type methodBuckets struct {
	methods map[SIPMethod]*tokenBucket

	// lastReset is when this set of buckets was created.
	lastReset time.Time

	// lastSeen is when Allow last ran for this IP; Cleanup evicts on it.
	lastSeen time.Time

	mu sync.Mutex // guards methods, lastSeen and the state of every bucket
}

// tokenBucket implements a simple token bucket algorithm.
type tokenBucket struct {
	tokens     float64
	lastUpdate time.Time
	maxTokens  float64
	refillRate float64 // tokens per second
}

// rateLimiterIdleTTL is how long an IP may go without any Allow call before
// Cleanup forgets its buckets.
const rateLimiterIdleTTL = 10 * time.Minute

// Global rate limiter instance.
var (
	globalRateLimiter *RateLimiter
	rateLimiterMu     sync.Mutex
)

// DefaultMethodLimits provides reasonable default rate limits per method.
// GetRateLimiter installs copies of these entries into the global limiter.
var DefaultMethodLimits = map[SIPMethod]*MethodRateLimit{
	MethodREGISTER: {
		Method:      MethodREGISTER,
		MaxRequests: 10,
		Window:      time.Minute,
		BurstSize:   3,
	},
	MethodINVITE: {
		Method:      MethodINVITE,
		MaxRequests: 30,
		Window:      time.Minute,
		BurstSize:   5,
	},
	MethodOPTIONS: {
		Method:      MethodOPTIONS,
		MaxRequests: 60,
		Window:      time.Minute,
		BurstSize:   10,
	},
	MethodSUBSCRIBE: {
		Method:      MethodSUBSCRIBE,
		MaxRequests: 20,
		Window:      time.Minute,
		BurstSize:   5,
	},
	MethodMESSAGE: {
		Method:      MethodMESSAGE,
		MaxRequests: 100,
		Window:      time.Minute,
		BurstSize:   20,
	},
}

// NewRateLimiter creates a rate limiter with no per-method limits configured;
// every method uses the default of 100 requests per minute until SetLimit is
// called.
func NewRateLimiter(logger *zap.Logger) *RateLimiter {
	return &RateLimiter{
		limits:             make(map[SIPMethod]*MethodRateLimit),
		buckets:            make(map[string]*methodBuckets),
		logger:             logger,
		defaultMaxRequests: 100,
		defaultWindow:      time.Minute,
	}
}

// GetRateLimiter returns the global rate limiter, creating it on first use
// with DefaultMethodLimits applied. logger is only used on that first call.
func GetRateLimiter(logger *zap.Logger) *RateLimiter {
	rateLimiterMu.Lock()
	defer rateLimiterMu.Unlock()

	if globalRateLimiter == nil {
		globalRateLimiter = NewRateLimiter(logger)
		for method, limit := range DefaultMethodLimits {
			if limit == nil {
				continue
			}
			// Store a copy so the limiter never aliases the package-level defaults.
			l := *limit
			globalRateLimiter.SetLimit(method, &l)
		}
	}
	return globalRateLimiter
}

// SetLimit configures the rate limit for a specific method. A nil limit
// removes any configured limit so the default applies again. Token buckets
// that already exist keep the parameters they were created with.
func (rl *RateLimiter) SetLimit(method SIPMethod, limit *MethodRateLimit) {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	if limit == nil {
		delete(rl.limits, method)
		return
	}
	rl.limits[method] = limit
}

// GetLimit returns the rate limit for a method: the configured one if any,
// otherwise a fresh MethodRateLimit built from the limiter's defaults.
func (rl *RateLimiter) GetLimit(method SIPMethod) *MethodRateLimit {
	rl.mu.RLock()
	defer rl.mu.RUnlock()
	return rl.limitForLocked(method)
}

// limitForLocked implements GetLimit; the caller must hold rl.mu.
func (rl *RateLimiter) limitForLocked(method SIPMethod) *MethodRateLimit {
	if limit, ok := rl.limits[method]; ok {
		return limit
	}
	return &MethodRateLimit{
		Method:      method,
		MaxRequests: rl.defaultMaxRequests,
		Window:      rl.defaultWindow,
	}
}

// Allow checks whether a request should be allowed based on rate limits.
// It returns (allowed, reason); when not allowed, reason is
// "rate_limit_<METHOD>".
func (rl *RateLimiter) Allow(ip string, method SIPMethod) (bool, string) {
	// Resolve the IP's buckets and the method's limit under rl.mu, then release
	// it before taking bucket.mu. Cleanup nests the locks as rl.mu -> bucket.mu,
	// so Allow must never acquire rl.mu while holding bucket.mu.
	rl.mu.Lock()
	bucket, ok := rl.buckets[ip]
	if !ok {
		created := time.Now()
		bucket = &methodBuckets{
			methods:   make(map[SIPMethod]*tokenBucket),
			lastReset: created,
			lastSeen:  created,
		}
		rl.buckets[ip] = bucket
	}
	limit := rl.limitForLocked(method)
	rl.mu.Unlock()

	bucket.mu.Lock()
	defer bucket.mu.Unlock()

	// Read the clock under bucket.mu so lastUpdate never moves backwards
	// between concurrent callers for the same IP.
	now := time.Now()
	bucket.lastSeen = now

	tb, ok := bucket.methods[method]
	if !ok {
		tb = rl.newMethodBucket(limit, now)
		bucket.methods[method] = tb
	}

	if tb.consume(now) {
		return true, ""
	}
	return false, "rate_limit_" + string(method)
}

// newMethodBucket builds a full token bucket for limit.
func (rl *RateLimiter) newMethodBucket(limit *MethodRateLimit, now time.Time) *tokenBucket {
	burst := limit.BurstSize
	if burst <= 0 {
		burst = limit.MaxRequests / 5 // default burst is 20% of max
		if burst < 1 {
			burst = 1
		}
	}

	window := limit.Window
	if window <= 0 {
		// A zero window would make the refill rate +Inf, and NaN once multiplied
		// by a zero elapsed time, which would deny the bucket forever.
		window = rl.defaultWindow
	}

	return &tokenBucket{
		tokens:     float64(burst),
		lastUpdate: now,
		maxTokens:  float64(burst),
		refillRate: float64(limit.MaxRequests) / window.Seconds(),
	}
}

// consume refills tb for the time elapsed since its last update (capped at
// maxTokens) and, if at least one whole token is available, spends it and
// reports true.
func (tb *tokenBucket) consume(now time.Time) bool {
	tb.tokens += now.Sub(tb.lastUpdate).Seconds() * tb.refillRate
	if tb.tokens > tb.maxTokens {
		tb.tokens = tb.maxTokens
	}
	tb.lastUpdate = now

	if tb.tokens < 1.0 {
		return false
	}
	tb.tokens -= 1.0
	return true
}

// Cleanup forgets every IP that has not called Allow within the last
// rateLimiterIdleTTL (10 minutes).
func (rl *RateLimiter) Cleanup() {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	cutoff := time.Now().Add(-rateLimiterIdleTTL)
	for ip, bucket := range rl.buckets {
		bucket.mu.Lock()
		idle := bucket.lastSeen.Before(cutoff)
		bucket.mu.Unlock()

		if idle {
			delete(rl.buckets, ip)
		}
	}
}

// GetStats returns rate limiter statistics: the number of tracked IPs and the
// configured per-method limits.
func (rl *RateLimiter) GetStats() map[string]interface{} {
	rl.mu.RLock()
	defer rl.mu.RUnlock()

	limitConfigs := make(map[string]interface{}, len(rl.limits))
	for method, limit := range rl.limits {
		limitConfigs[string(method)] = map[string]interface{}{
			"max_requests": limit.MaxRequests,
			"window":       limit.Window.String(),
			"burst_size":   limit.BurstSize,
		}
	}

	return map[string]interface{}{
		"tracked_ips": len(rl.buckets),
		"limits":      limitConfigs,
	}
}

// SIP message extraction patterns. Both the sip: and sips: URI schemes are
// accepted (RFC 3261 §19.1).
var (
	sipMethodPattern = regexp.MustCompile(`^(REGISTER|INVITE|OPTIONS|ACK|BYE|CANCEL|INFO|NOTIFY|SUBSCRIBE|MESSAGE|UPDATE|PRACK|REFER|PUBLISH)\s+sips?:`)

	// Pattern to extract target extension from Request-URI: METHOD sip:extension@domain
	sipTargetExtPattern = regexp.MustCompile(`^(?:REGISTER|INVITE|OPTIONS|MESSAGE|SUBSCRIBE)\s+sips?:([^@\s>]+)@`)

	// sipURIUserHostPattern extracts user and host from a sip:user@host URI
	// inside a header value. The user part may not cross '>' or whitespace, so
	// a user-less URI cannot absorb text that follows it.
	sipURIUserHostPattern = regexp.MustCompile(`sips?:([^@>\s]+)@([^>;\s]+)`)
)

// ExtractSIPMethod extracts the SIP method from the request line of a message.
// It returns "" if data does not start with a recognised request line.
func ExtractSIPMethod(data []byte) SIPMethod {
	if matches := sipMethodPattern.FindSubmatch(data); matches != nil {
		return SIPMethod(matches[1])
	}
	return ""
}

// ExtractTargetExtension extracts the target extension from a SIP message.
// It looks at the Request-URI first, then falls back to the To header. Only
// values that look like extensions (see isPlausibleExtension) are returned.
func ExtractTargetExtension(data []byte) string {
	// First try Request-URI: REGISTER sip:1001@domain.com
	if matches := sipTargetExtPattern.FindSubmatch(data); matches != nil {
		if ext := string(matches[1]); isPlausibleExtension(ext) {
			return ext
		}
	}

	// Fall back to To header.
	if user, _ := ParseToHeader(data); isPlausibleExtension(user) {
		return user
	}
	return ""
}

// isPlausibleExtension reports whether s looks like an extension rather than a
// domain-like or overly long value: non-empty, at most 10 bytes, no dots.
func isPlausibleExtension(s string) bool {
	return s != "" && len(s) <= 10 && !strings.Contains(s, ".")
}

// ParseUserAgent extracts the User-Agent header value from a SIP message.
func ParseUserAgent(data []byte) string {
	return findSIPHeader(data, "user-agent")
}

// ParseFromHeader extracts the user and domain of the From header URI
// (long form "From" or compact form "f").
func ParseFromHeader(data []byte) (user, domain string) {
	return parseSIPURIHeader(data, "from", "f")
}

// ParseToHeader extracts the user and domain of the To header URI
// (long form "To" or compact form "t").
func ParseToHeader(data []byte) (user, domain string) {
	return parseSIPURIHeader(data, "to", "t")
}

// ParseCallID extracts the Call-ID header value (long form "Call-ID" or
// compact form "i") from a SIP message.
func ParseCallID(data []byte) string {
	return findSIPHeader(data, "call-id", "i")
}

// parseSIPURIHeader finds the first header named by names and returns the
// user and host of the sip:/sips: URI in its value, or "", "" if none.
func parseSIPURIHeader(data []byte, names ...string) (user, domain string) {
	value := findSIPHeader(data, names...)
	if matches := sipURIUserHostPattern.FindStringSubmatch(value); matches != nil {
		return matches[1], matches[2]
	}
	return "", ""
}

// findSIPHeader returns the trimmed value of the first header line whose name
// case-insensitively matches one of names (checked in order for each line).
// Only the header section is searched: scanning stops at the first empty line
// after the start line. Both CRLF and bare LF line endings are accepted, and
// empty lines before the start line are skipped. It returns "" if no header
// matches.
func findSIPHeader(data []byte, names ...string) string {
	rest := string(data)
	seenStartLine := false

	for len(rest) > 0 {
		var line string
		if i := strings.IndexByte(rest, '\n'); i >= 0 {
			line, rest = rest[:i], rest[i+1:]
		} else {
			line, rest = rest, ""
		}
		line = strings.TrimSuffix(line, "\r")

		if line == "" {
			if seenStartLine {
				break // the empty line ends the header section; the body follows
			}
			continue // tolerate CRLFs sent ahead of the start line
		}
		seenStartLine = true

		for _, name := range names {
			if value, ok := sipHeaderFieldValue(line, name); ok {
				return value
			}
		}
	}
	return ""
}

// sipHeaderFieldValue reports whether line is a header field called name
// (case-insensitive, optional spaces/tabs before the colon) and, if so,
// returns its whitespace-trimmed value. The comparison is done on the
// original bytes, so non-ASCII content cannot shift the value offset.
func sipHeaderFieldValue(line, name string) (string, bool) {
	if len(line) < len(name) || !strings.EqualFold(line[:len(name)], name) {
		return "", false
	}
	afterName := strings.TrimLeft(line[len(name):], " \t")
	if !strings.HasPrefix(afterName, ":") {
		return "", false
	}
	return strings.TrimSpace(afterName[1:]), true
}