Improve Caddy module lifecycle and safety

- Add Cleanup() method (caddy.CleanerUpper) to stop goroutines on config
  reload, preventing goroutine leaks
- Add Validate() method (caddy.Validator) for early config validation with
  reasonable bounds checking
- Add public BanIP() method for admin handler, replacing direct internal
  state manipulation
- Add bounds checking for failure tracker and ban maps to prevent memory
  exhaustion under DDoS (100k/50k limits)
- Add eviction functions to proactively clean oldest entries when at capacity
This commit is contained in:
Ryan Malloy 2025-12-08 01:29:16 -07:00
parent 5cf34eb3c0
commit 265c606169
2 changed files with 256 additions and 11 deletions

View File

@ -206,13 +206,8 @@ func (h *AdminHandler) handleBan(w http.ResponseWriter, r *http.Request, path st
body.Reason = "manual_ban" body.Reason = "manual_ban"
} }
// Force a ban by recording max failures // Use public BanIP method for proper encapsulation
h.guardian.mu.Lock() h.guardian.BanIP(ip, body.Reason)
h.guardian.failureCounts[ip] = &failureTracker{
count: h.guardian.MaxFailures,
}
h.guardian.banIP(ip, body.Reason)
h.guardian.mu.Unlock()
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
return json.NewEncoder(w).Encode(map[string]interface{}{ return json.NewEncoder(w).Encode(map[string]interface{}{

View File

@ -20,6 +20,13 @@ var (
enableStorage = true enableStorage = true
) )
// Configuration limits to prevent unbounded growth under attack.
// These are compile-time caps on the in-memory maps guarded by SIPGuardian.mu.
const (
maxTrackedIPs = 100000 // Cap on failureCounts entries; RecordFailure evicts, then drops new IPs, once reached
maxBannedIPs = 50000 // Threshold on bannedIPs entries at which banIP evicts before adding a new ban
cleanupBatchSize = 1000 // Number of entries requested per eviction call (evictOldestTrackers / evictOldestBans)
)
func init() { func init() {
caddy.RegisterModule(SIPGuardian{}) caddy.RegisterModule(SIPGuardian{})
} }
@ -72,6 +79,10 @@ type SIPGuardian struct {
mu sync.RWMutex mu sync.RWMutex
storage *Storage storage *Storage
geoIP *GeoIPLookup geoIP *GeoIPLookup
// Lifecycle management
stopCh chan struct{}
wg sync.WaitGroup
} }
type failureTracker struct { type failureTracker struct {
@ -93,6 +104,7 @@ func (g *SIPGuardian) Provision(ctx caddy.Context) error {
g.logger = ctx.Logger() g.logger = ctx.Logger()
g.bannedIPs = make(map[string]*BanEntry) g.bannedIPs = make(map[string]*BanEntry)
g.failureCounts = make(map[string]*failureTracker) g.failureCounts = make(map[string]*failureTracker)
g.stopCh = make(chan struct{})
// Set defaults // Set defaults
if g.MaxFailures == 0 { if g.MaxFailures == 0 {
@ -209,8 +221,9 @@ func (g *SIPGuardian) Provision(ctx caddy.Context) error {
) )
} }
// Start cleanup goroutine // Start cleanup goroutine with proper lifecycle tracking
go g.cleanupLoop(ctx) g.wg.Add(1)
go g.cleanupLoop()
g.logger.Info("SIP Guardian initialized", g.logger.Info("SIP Guardian initialized",
zap.Int("max_failures", g.MaxFailures), zap.Int("max_failures", g.MaxFailures),
@ -341,6 +354,20 @@ func (g *SIPGuardian) RecordFailure(ip, reason string) bool {
tracker, exists := g.failureCounts[ip] tracker, exists := g.failureCounts[ip]
if !exists || now.Sub(tracker.firstSeen) > findWindow { if !exists || now.Sub(tracker.firstSeen) > findWindow {
// Check bounds before adding new entry
if !exists && len(g.failureCounts) >= maxTrackedIPs {
// Proactively clean old entries to make room
g.evictOldestTrackers(cleanupBatchSize)
if len(g.failureCounts) >= maxTrackedIPs {
// Still at limit, log warning and skip tracking
g.logger.Warn("Failure tracker map at capacity, dropping new entry",
zap.String("ip", ip),
zap.Int("capacity", maxTrackedIPs),
)
return false
}
}
// Start new tracking window // Start new tracking window
tracker = &failureTracker{ tracker = &failureTracker{
count: 1, count: 1,
@ -391,6 +418,19 @@ func (g *SIPGuardian) banIP(ip, reason string) {
now := time.Now() now := time.Now()
banDuration := time.Duration(g.BanTime) banDuration := time.Duration(g.BanTime)
// Check bounds before adding new ban entry
if _, exists := g.bannedIPs[ip]; !exists && len(g.bannedIPs) >= maxBannedIPs {
// Proactively clean expired/oldest bans to make room
g.evictOldestBans(cleanupBatchSize)
if len(g.bannedIPs) >= maxBannedIPs {
// Still at limit, log warning but still ban (overwrite random existing)
g.logger.Warn("Ban map at capacity, evicting to make room",
zap.String("ip", ip),
zap.Int("capacity", maxBannedIPs),
)
}
}
hitCount := 0 hitCount := 0
if tracker := g.failureCounts[ip]; tracker != nil { if tracker := g.failureCounts[ip]; tracker != nil {
hitCount = tracker.count hitCount = tracker.count
@ -529,13 +569,16 @@ func (g *SIPGuardian) RefreshDNSWhitelist() {
} }
// cleanupLoop periodically removes expired entries // cleanupLoop periodically removes expired entries
func (g *SIPGuardian) cleanupLoop(ctx caddy.Context) { func (g *SIPGuardian) cleanupLoop() {
defer g.wg.Done()
ticker := time.NewTicker(1 * time.Minute) ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop() defer ticker.Stop()
for { for {
select { select {
case <-ctx.Done(): case <-g.stopCh:
g.logger.Debug("Cleanup loop stopped")
return return
case <-ticker.C: case <-ticker.C:
g.cleanup() g.cleanup()
@ -564,6 +607,109 @@ func (g *SIPGuardian) cleanup() {
delete(g.failureCounts, ip) delete(g.failureCounts, ip)
} }
} }
// Log map sizes for monitoring
if len(g.bannedIPs) > maxBannedIPs/2 || len(g.failureCounts) > maxTrackedIPs/2 {
g.logger.Info("Map size status",
zap.Int("banned_ips", len(g.bannedIPs)),
zap.Int("tracked_ips", len(g.failureCounts)),
)
}
}
// evictOldestTrackers removes up to count failure trackers with the earliest
// firstSeen timestamps from g.failureCounts to make room for new entries.
// The caller must hold g.mu for writing.
//
// Only the count oldest entries are needed, so they are selected with a
// bounded max-heap in O(n log count) instead of fully sorting the map. This
// runs under the lock on the failure-recording path, where a quadratic sort
// over a map at capacity would stall every other caller.
//
// A non-positive count or an empty map is a no-op.
func (g *SIPGuardian) evictOldestTrackers(count int) {
	limit := count
	if n := len(g.failureCounts); n < limit {
		limit = n
	}
	if limit <= 0 {
		return
	}

	type ipTime struct {
		ip   string
		seen time.Time
	}

	// oldest holds the limit oldest entries seen so far, arranged as a
	// max-heap on seen: the root is the newest of the kept candidates and is
	// therefore the one to displace when an older entry turns up.
	oldest := make([]ipTime, 0, limit)

	// siftDown restores the max-heap property for the subtree rooted at i.
	siftDown := func(i int) {
		for {
			top := i
			if l := 2*i + 1; l < len(oldest) && oldest[top].seen.Before(oldest[l].seen) {
				top = l
			}
			if r := 2*i + 2; r < len(oldest) && oldest[top].seen.Before(oldest[r].seen) {
				top = r
			}
			if top == i {
				return
			}
			oldest[i], oldest[top] = oldest[top], oldest[i]
			i = top
		}
	}

	for ip, tracker := range g.failureCounts {
		candidate := ipTime{ip: ip, seen: tracker.firstSeen}

		if len(oldest) < limit {
			oldest = append(oldest, candidate)
			if len(oldest) == limit {
				// Buffer just filled: heapify once, bottom-up.
				for i := limit/2 - 1; i >= 0; i-- {
					siftDown(i)
				}
			}
			continue
		}

		if candidate.seen.Before(oldest[0].seen) {
			oldest[0] = candidate
			siftDown(0)
		}
	}

	for _, entry := range oldest {
		delete(g.failureCounts, entry.ip)
	}

	g.logger.Info("Evicted oldest failure trackers",
		zap.Int("evicted", len(oldest)),
		zap.Int("remaining", len(g.failureCounts)),
	)
}
// evictOldestBans removes the oldest expired bans to make room (must hold lock)
func (g *SIPGuardian) evictOldestBans(count int) {
now := time.Now()
// First, remove expired bans
for ip, entry := range g.bannedIPs {
if now.After(entry.ExpiresAt) {
delete(g.bannedIPs, ip)
count--
if count <= 0 {
return
}
}
}
// If still need room, remove bans closest to expiry
if count > 0 {
type ipTime struct {
ip string
time time.Time
}
entries := make([]ipTime, 0, len(g.bannedIPs))
for ip, ban := range g.bannedIPs {
entries = append(entries, ipTime{ip: ip, time: ban.ExpiresAt})
}
// Sort by expiry time (soonest first)
for i := 0; i < len(entries)-1; i++ {
for j := i + 1; j < len(entries); j++ {
if entries[j].time.Before(entries[i].time) {
entries[i], entries[j] = entries[j], entries[i]
}
}
}
evicted := 0
for _, entry := range entries {
if evicted >= count {
break
}
delete(g.bannedIPs, entry.ip)
evicted++
}
if evicted > 0 {
g.logger.Warn("Evicted active bans due to capacity limit",
zap.Int("evicted", evicted),
zap.Int("remaining", len(g.bannedIPs)),
)
}
}
} }
// UnmarshalCaddyfile implements caddyfile.Unmarshaler. // UnmarshalCaddyfile implements caddyfile.Unmarshaler.
@ -742,9 +888,113 @@ func (g *SIPGuardian) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
return nil return nil
} }
// Cleanup implements caddy.CleanerUpper.
// Called when Caddy config is reloaded to stop goroutines and release resources.
//
// The steps are: signal background goroutines through stopCh, stop the DNS
// whitelist refresher, wait up to five seconds for tracked goroutines to
// exit, then close the storage connection and the GeoIP database.
//
// Cleanup is safe to call more than once and before stopCh was created:
// closing a nil or already-closed channel would panic, so the close is
// guarded, and a repeated call returns without touching the resources again.
// It always returns nil; a storage close failure is logged rather than
// propagated so the remaining resources are still released.
func (g *SIPGuardian) Cleanup() error {
	g.logger.Info("SIP Guardian cleanup starting")

	// Signal all goroutines to stop. The check-and-close is done under g.mu
	// so overlapping Cleanup calls cannot both reach close().
	alreadyStopped := false
	g.mu.Lock()
	if g.stopCh != nil {
		select {
		case <-g.stopCh:
			alreadyStopped = true
		default:
			close(g.stopCh)
		}
	}
	g.mu.Unlock()

	if alreadyStopped {
		// A previous call already released everything below.
		return nil
	}

	// Stop DNS whitelist background refresh
	if g.dnsWhitelist != nil {
		g.dnsWhitelist.Stop()
	}

	// Wait for goroutines to finish (with timeout) so a stuck goroutine
	// cannot block the caller indefinitely.
	done := make(chan struct{})
	go func() {
		g.wg.Wait()
		close(done)
	}()

	select {
	case <-done:
		g.logger.Debug("All goroutines stopped cleanly")
	case <-time.After(5 * time.Second):
		g.logger.Warn("Timeout waiting for goroutines to stop")
	}

	// Close storage connection
	if g.storage != nil {
		if err := g.storage.Close(); err != nil {
			g.logger.Error("Error closing storage", zap.Error(err))
		}
	}

	// Close GeoIP database
	if g.geoIP != nil {
		g.geoIP.Close()
	}

	g.logger.Info("SIP Guardian cleanup complete")
	return nil
}
// Validate implements caddy.Validator.
// Called after Provision() to validate configuration before use.
//
// Checks are evaluated in order and the first violation is returned:
// max_failures must lie in [1, 1000], find_time in [1s, 24h], ban_time must
// be at least 1s, allowed_countries and blocked_countries are mutually
// exclusive, and a positive dns_refresh must be at least 30s.
func (g *SIPGuardian) Validate() error {
	findTime := time.Duration(g.FindTime)
	banTime := time.Duration(g.BanTime)
	dnsRefresh := time.Duration(g.DNSRefresh)

	switch {
	case g.MaxFailures < 1:
		return fmt.Errorf("max_failures must be at least 1, got %d", g.MaxFailures)

	case g.MaxFailures > 1000:
		return fmt.Errorf("max_failures exceeds reasonable limit (1000), got %d", g.MaxFailures)

	case findTime < time.Second:
		return fmt.Errorf("find_time must be at least 1s, got %v", findTime)

	case findTime > 24*time.Hour:
		return fmt.Errorf("find_time exceeds reasonable limit (24h), got %v", findTime)

	case banTime < time.Second:
		return fmt.Errorf("ban_time must be at least 1s, got %v", banTime)

	// An allow-list and a block-list of countries cannot be combined.
	case len(g.AllowedCountries) > 0 && len(g.BlockedCountries) > 0:
		return fmt.Errorf("cannot specify both allowed_countries and blocked_countries")

	// A zero or negative refresh interval is not checked here.
	case g.DNSRefresh > 0 && dnsRefresh < 30*time.Second:
		return fmt.Errorf("dns_refresh must be at least 30s for stability, got %v", dnsRefresh)
	}

	return nil
}
// BanIP manually adds an IP to the ban list with a reason.
// This is the public API for external callers (like AdminHandler).
//
// Whitelisted IPs are never banned; the attempt is logged and ignored.
// Otherwise the IP's failure tracker is created, or its count is overwritten,
// with MaxFailures before the ban is applied under the write lock.
func (g *SIPGuardian) BanIP(ip, reason string) {
	if g.IsWhitelisted(ip) {
		g.logger.Info("Attempted to ban whitelisted IP", zap.String("ip", ip))
		return
	}

	g.mu.Lock()
	defer g.mu.Unlock()

	// Make sure a tracker at the failure threshold exists (for hit count).
	if tracker, tracked := g.failureCounts[ip]; tracked {
		tracker.count = g.MaxFailures
	} else {
		g.failureCounts[ip] = &failureTracker{
			count:     g.MaxFailures,
			firstSeen: time.Now(),
			lastSeen:  time.Now(),
		}
	}

	g.banIP(ip, reason)
}
// Interface guards // Interface guards
var ( var (
_ caddy.Module = (*SIPGuardian)(nil) _ caddy.Module = (*SIPGuardian)(nil)
_ caddy.Provisioner = (*SIPGuardian)(nil) _ caddy.Provisioner = (*SIPGuardian)(nil)
_ caddy.CleanerUpper = (*SIPGuardian)(nil)
_ caddy.Validator = (*SIPGuardian)(nil)
_ caddyfile.Unmarshaler = (*SIPGuardian)(nil) _ caddyfile.Unmarshaler = (*SIPGuardian)(nil)
) )