From 265c6061693f491e2f54c2b8b2d9ab797c699eb0 Mon Sep 17 00:00:00 2001 From: Ryan Malloy Date: Mon, 8 Dec 2025 01:29:16 -0700 Subject: [PATCH] Improve Caddy module lifecycle and safety - Add Cleanup() method (caddy.CleanerUpper) to stop goroutines on config reload, preventing goroutine leaks - Add Validate() method (caddy.Validator) for early config validation with reasonable bounds checking - Add public BanIP() method for admin handler, replacing direct internal state manipulation - Add bounds checking for failure tracker and ban maps to prevent memory exhaustion under DDoS (100k/50k limits) - Add eviction functions to proactively clean oldest entries when at capacity --- admin.go | 9 +- sipguardian.go | 258 ++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 256 insertions(+), 11 deletions(-) diff --git a/admin.go b/admin.go index e7f0462..5e56558 100644 --- a/admin.go +++ b/admin.go @@ -206,13 +206,8 @@ func (h *AdminHandler) handleBan(w http.ResponseWriter, r *http.Request, path st body.Reason = "manual_ban" } - // Force a ban by recording max failures - h.guardian.mu.Lock() - h.guardian.failureCounts[ip] = &failureTracker{ - count: h.guardian.MaxFailures, - } - h.guardian.banIP(ip, body.Reason) - h.guardian.mu.Unlock() + // Use public BanIP method for proper encapsulation + h.guardian.BanIP(ip, body.Reason) w.Header().Set("Content-Type", "application/json") return json.NewEncoder(w).Encode(map[string]interface{}{ diff --git a/sipguardian.go b/sipguardian.go index 98d48cd..75dc711 100644 --- a/sipguardian.go +++ b/sipguardian.go @@ -20,6 +20,13 @@ var ( enableStorage = true ) +// Configuration limits to prevent unbounded growth under attack +const ( + maxTrackedIPs = 100000 // Max IPs to track failures for + maxBannedIPs = 50000 // Max banned IPs to hold in memory + cleanupBatchSize = 1000 // Batch size passed to the evict* helpers when a map is at capacity +) + func init() { caddy.RegisterModule(SIPGuardian{}) } @@ -72,6 +79,10 @@ type SIPGuardian struct { mu sync.RWMutex storage 
*Storage geoIP *GeoIPLookup + + // Lifecycle management + stopCh chan struct{} + wg sync.WaitGroup } type failureTracker struct { @@ -93,6 +104,7 @@ func (g *SIPGuardian) Provision(ctx caddy.Context) error { g.logger = ctx.Logger() g.bannedIPs = make(map[string]*BanEntry) g.failureCounts = make(map[string]*failureTracker) + g.stopCh = make(chan struct{}) // Set defaults if g.MaxFailures == 0 { @@ -209,8 +221,9 @@ func (g *SIPGuardian) Provision(ctx caddy.Context) error { ) } - // Start cleanup goroutine - go g.cleanupLoop(ctx) + // Start cleanup goroutine with proper lifecycle tracking + g.wg.Add(1) + go g.cleanupLoop() g.logger.Info("SIP Guardian initialized", zap.Int("max_failures", g.MaxFailures), @@ -341,6 +354,20 @@ func (g *SIPGuardian) RecordFailure(ip, reason string) bool { tracker, exists := g.failureCounts[ip] if !exists || now.Sub(tracker.firstSeen) > findWindow { + // Check bounds before adding new entry + if !exists && len(g.failureCounts) >= maxTrackedIPs { + // Proactively clean old entries to make room + g.evictOldestTrackers(cleanupBatchSize) + if len(g.failureCounts) >= maxTrackedIPs { + // Still at limit, log warning and skip tracking + g.logger.Warn("Failure tracker map at capacity, dropping new entry", + zap.String("ip", ip), + zap.Int("capacity", maxTrackedIPs), + ) + return false + } + } + // Start new tracking window tracker = &failureTracker{ count: 1, @@ -391,6 +418,19 @@ func (g *SIPGuardian) banIP(ip, reason string) { now := time.Now() banDuration := time.Duration(g.BanTime) + // Check bounds before adding new ban entry + if _, exists := g.bannedIPs[ip]; !exists && len(g.bannedIPs) >= maxBannedIPs { + // Proactively clean expired/oldest bans to make room + g.evictOldestBans(cleanupBatchSize) + if len(g.bannedIPs) >= maxBannedIPs { + // Still at limit: only log a warning and fall through; nothing is overwritten or evicted in this branch + g.logger.Warn("Ban map at capacity, evicting to make room", + zap.String("ip", ip), + zap.Int("capacity", maxBannedIPs), + ) + } + } 
+ hitCount := 0 if tracker := g.failureCounts[ip]; tracker != nil { hitCount = tracker.count @@ -529,13 +569,16 @@ func (g *SIPGuardian) RefreshDNSWhitelist() { } // cleanupLoop periodically removes expired entries -func (g *SIPGuardian) cleanupLoop(ctx caddy.Context) { +func (g *SIPGuardian) cleanupLoop() { + defer g.wg.Done() + ticker := time.NewTicker(1 * time.Minute) defer ticker.Stop() for { select { - case <-ctx.Done(): + case <-g.stopCh: + g.logger.Debug("Cleanup loop stopped") return case <-ticker.C: g.cleanup() @@ -564,6 +607,109 @@ func (g *SIPGuardian) cleanup() { delete(g.failureCounts, ip) } } + + // Log map sizes for monitoring + if len(g.bannedIPs) > maxBannedIPs/2 || len(g.failureCounts) > maxTrackedIPs/2 { + g.logger.Info("Map size status", + zap.Int("banned_ips", len(g.bannedIPs)), + zap.Int("tracked_ips", len(g.failureCounts)), + ) + } +} + +// evictOldestTrackers removes the oldest failure trackers to make room (must hold lock) +func (g *SIPGuardian) evictOldestTrackers(count int) { + // Find and evict the oldest entries by firstSeen + type ipTime struct { + ip string + time time.Time + } + + // Collect all entries + entries := make([]ipTime, 0, len(g.failureCounts)) + for ip, tracker := range g.failureCounts { + entries = append(entries, ipTime{ip: ip, time: tracker.firstSeen}) + } + + // Sort by time (oldest first). NOTE(review): quadratic exchange sort over every tracked entry, run while the caller holds the lock; consider sort.Slice + for i := 0; i < len(entries)-1; i++ { + for j := i + 1; j < len(entries); j++ { + if entries[j].time.Before(entries[i].time) { + entries[i], entries[j] = entries[j], entries[i] + } + } + } + + // Evict oldest entries + evicted := 0 + for _, entry := range entries { + if evicted >= count { + break + } + delete(g.failureCounts, entry.ip) + evicted++ + } + + if evicted > 0 { + g.logger.Info("Evicted oldest failure trackers", + zap.Int("evicted", evicted), + zap.Int("remaining", len(g.failureCounts)), + ) + } +} + +// evictOldestBans removes up to count bans to make room: expired bans first, then the active bans closest to expiry (must hold lock) +func (g *SIPGuardian) 
evictOldestBans(count int) { + now := time.Now() + + // First, remove expired bans + for ip, entry := range g.bannedIPs { + if now.After(entry.ExpiresAt) { + delete(g.bannedIPs, ip) + count-- + if count <= 0 { + return + } + } + } + + // If still need room, remove bans closest to expiry + if count > 0 { + type ipTime struct { + ip string + time time.Time + } + + entries := make([]ipTime, 0, len(g.bannedIPs)) + for ip, ban := range g.bannedIPs { + entries = append(entries, ipTime{ip: ip, time: ban.ExpiresAt}) + } + + // Sort by expiry time (soonest first). NOTE(review): quadratic exchange sort over every remaining ban, run while the caller holds the lock; consider sort.Slice + for i := 0; i < len(entries)-1; i++ { + for j := i + 1; j < len(entries); j++ { + if entries[j].time.Before(entries[i].time) { + entries[i], entries[j] = entries[j], entries[i] + } + } + } + + evicted := 0 + for _, entry := range entries { + if evicted >= count { + break + } + delete(g.bannedIPs, entry.ip) + evicted++ + } + + if evicted > 0 { + g.logger.Warn("Evicted active bans due to capacity limit", + zap.Int("evicted", evicted), + zap.Int("remaining", len(g.bannedIPs)), + ) + } + } } // UnmarshalCaddyfile implements caddyfile.Unmarshaler. @@ -742,9 +888,113 @@ func (g *SIPGuardian) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { return nil } +// Cleanup implements caddy.CleanerUpper. +// Called when Caddy config is reloaded to stop goroutines and release resources. 
+func (g *SIPGuardian) Cleanup() error { + g.logger.Info("SIP Guardian cleanup starting") + + // Signal all goroutines to stop. NOTE(review): close panics on an already-closed channel, so this assumes Cleanup runs at most once per Provision + close(g.stopCh) + + // Stop DNS whitelist background refresh + if g.dnsWhitelist != nil { + g.dnsWhitelist.Stop() + } + + // Wait for goroutines to finish (with timeout) + done := make(chan struct{}) + go func() { + g.wg.Wait() + close(done) + }() + + select { + case <-done: + g.logger.Debug("All goroutines stopped cleanly") + case <-time.After(5 * time.Second): + g.logger.Warn("Timeout waiting for goroutines to stop") + } + + // Close storage connection + if g.storage != nil { + if err := g.storage.Close(); err != nil { + g.logger.Error("Error closing storage", zap.Error(err)) + } + } + + // Close GeoIP database + if g.geoIP != nil { + g.geoIP.Close() + } + + g.logger.Info("SIP Guardian cleanup complete") + return nil +} + +// Validate implements caddy.Validator. +// Called after Provision() to validate configuration before use. +func (g *SIPGuardian) Validate() error { + if g.MaxFailures < 1 { + return fmt.Errorf("max_failures must be at least 1, got %d", g.MaxFailures) + } + if g.MaxFailures > 1000 { + return fmt.Errorf("max_failures exceeds reasonable limit (1000), got %d", g.MaxFailures) + } + + if time.Duration(g.FindTime) < time.Second { + return fmt.Errorf("find_time must be at least 1s, got %v", time.Duration(g.FindTime)) + } + if time.Duration(g.FindTime) > 24*time.Hour { + return fmt.Errorf("find_time exceeds reasonable limit (24h), got %v", time.Duration(g.FindTime)) + } + + if time.Duration(g.BanTime) < time.Second { + return fmt.Errorf("ban_time must be at least 1s, got %v", time.Duration(g.BanTime)) + } + + // Validate conflicting country configurations + if len(g.AllowedCountries) > 0 && len(g.BlockedCountries) > 0 { + return fmt.Errorf("cannot specify both allowed_countries and blocked_countries") + } + + // Validate DNS refresh interval + if g.DNSRefresh > 0 && time.Duration(g.DNSRefresh) < 30*time.Second { + return 
fmt.Errorf("dns_refresh must be at least 30s for stability, got %v", time.Duration(g.DNSRefresh)) + } + + return nil +} + +// BanIP manually adds an IP to the ban list with a reason. +// This is the public API for external callers (like AdminHandler). Whitelisted IPs are logged and the call returns without banning. +func (g *SIPGuardian) BanIP(ip, reason string) { + if g.IsWhitelisted(ip) { + g.logger.Info("Attempted to ban whitelisted IP", zap.String("ip", ip)) + return + } + + g.mu.Lock() + defer g.mu.Unlock() + + // Ensure a failure tracker exists with count set to MaxFailures (banIP reads it for the hit count) + if _, exists := g.failureCounts[ip]; !exists { + g.failureCounts[ip] = &failureTracker{ + count: g.MaxFailures, + firstSeen: time.Now(), + lastSeen: time.Now(), + } + } else { + g.failureCounts[ip].count = g.MaxFailures + } + + g.banIP(ip, reason) +} + // Interface guards var ( _ caddy.Module = (*SIPGuardian)(nil) _ caddy.Provisioner = (*SIPGuardian)(nil) + _ caddy.CleanerUpper = (*SIPGuardian)(nil) + _ caddy.Validator = (*SIPGuardian)(nil) _ caddyfile.Unmarshaler = (*SIPGuardian)(nil) )