Improve Caddy module lifecycle and safety
- Add Cleanup() method (caddy.CleanerUpper) to stop goroutines on config reload, preventing goroutine leaks - Add Validate() method (caddy.Validator) for early config validation with reasonable bounds checking - Add public BanIP() method for admin handler, replacing direct internal state manipulation - Add bounds checking for failure tracker and ban maps to prevent memory exhaustion under DDoS (100k/50k limits) - Add eviction functions to proactively clean oldest entries when at capacity
This commit is contained in:
parent
5cf34eb3c0
commit
265c606169
9
admin.go
9
admin.go
@ -206,13 +206,8 @@ func (h *AdminHandler) handleBan(w http.ResponseWriter, r *http.Request, path st
|
||||
body.Reason = "manual_ban"
|
||||
}
|
||||
|
||||
// Force a ban by recording max failures
|
||||
h.guardian.mu.Lock()
|
||||
h.guardian.failureCounts[ip] = &failureTracker{
|
||||
count: h.guardian.MaxFailures,
|
||||
}
|
||||
h.guardian.banIP(ip, body.Reason)
|
||||
h.guardian.mu.Unlock()
|
||||
// Use public BanIP method for proper encapsulation
|
||||
h.guardian.BanIP(ip, body.Reason)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
return json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
|
||||
258
sipguardian.go
258
sipguardian.go
@ -20,6 +20,13 @@ var (
|
||||
enableStorage = true
|
||||
)
|
||||
|
||||
// Configuration limits to prevent unbounded growth under attack
|
||||
const (
|
||||
maxTrackedIPs = 100000 // Max IPs to track failures for
|
||||
maxBannedIPs = 50000 // Max banned IPs to hold in memory
|
||||
cleanupBatchSize = 1000 // Max entries to clean per cycle
|
||||
)
|
||||
|
||||
func init() {
|
||||
caddy.RegisterModule(SIPGuardian{})
|
||||
}
|
||||
@ -72,6 +79,10 @@ type SIPGuardian struct {
|
||||
mu sync.RWMutex
|
||||
storage *Storage
|
||||
geoIP *GeoIPLookup
|
||||
|
||||
// Lifecycle management
|
||||
stopCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
type failureTracker struct {
|
||||
@ -93,6 +104,7 @@ func (g *SIPGuardian) Provision(ctx caddy.Context) error {
|
||||
g.logger = ctx.Logger()
|
||||
g.bannedIPs = make(map[string]*BanEntry)
|
||||
g.failureCounts = make(map[string]*failureTracker)
|
||||
g.stopCh = make(chan struct{})
|
||||
|
||||
// Set defaults
|
||||
if g.MaxFailures == 0 {
|
||||
@ -209,8 +221,9 @@ func (g *SIPGuardian) Provision(ctx caddy.Context) error {
|
||||
)
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
go g.cleanupLoop(ctx)
|
||||
// Start cleanup goroutine with proper lifecycle tracking
|
||||
g.wg.Add(1)
|
||||
go g.cleanupLoop()
|
||||
|
||||
g.logger.Info("SIP Guardian initialized",
|
||||
zap.Int("max_failures", g.MaxFailures),
|
||||
@ -341,6 +354,20 @@ func (g *SIPGuardian) RecordFailure(ip, reason string) bool {
|
||||
|
||||
tracker, exists := g.failureCounts[ip]
|
||||
if !exists || now.Sub(tracker.firstSeen) > findWindow {
|
||||
// Check bounds before adding new entry
|
||||
if !exists && len(g.failureCounts) >= maxTrackedIPs {
|
||||
// Proactively clean old entries to make room
|
||||
g.evictOldestTrackers(cleanupBatchSize)
|
||||
if len(g.failureCounts) >= maxTrackedIPs {
|
||||
// Still at limit, log warning and skip tracking
|
||||
g.logger.Warn("Failure tracker map at capacity, dropping new entry",
|
||||
zap.String("ip", ip),
|
||||
zap.Int("capacity", maxTrackedIPs),
|
||||
)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Start new tracking window
|
||||
tracker = &failureTracker{
|
||||
count: 1,
|
||||
@ -391,6 +418,19 @@ func (g *SIPGuardian) banIP(ip, reason string) {
|
||||
now := time.Now()
|
||||
banDuration := time.Duration(g.BanTime)
|
||||
|
||||
// Check bounds before adding new ban entry
|
||||
if _, exists := g.bannedIPs[ip]; !exists && len(g.bannedIPs) >= maxBannedIPs {
|
||||
// Proactively clean expired/oldest bans to make room
|
||||
g.evictOldestBans(cleanupBatchSize)
|
||||
if len(g.bannedIPs) >= maxBannedIPs {
|
||||
// Still at limit, log warning but still ban (overwrite random existing)
|
||||
g.logger.Warn("Ban map at capacity, evicting to make room",
|
||||
zap.String("ip", ip),
|
||||
zap.Int("capacity", maxBannedIPs),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
hitCount := 0
|
||||
if tracker := g.failureCounts[ip]; tracker != nil {
|
||||
hitCount = tracker.count
|
||||
@ -529,13 +569,16 @@ func (g *SIPGuardian) RefreshDNSWhitelist() {
|
||||
}
|
||||
|
||||
// cleanupLoop periodically removes expired entries
|
||||
func (g *SIPGuardian) cleanupLoop(ctx caddy.Context) {
|
||||
func (g *SIPGuardian) cleanupLoop() {
|
||||
defer g.wg.Done()
|
||||
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-g.stopCh:
|
||||
g.logger.Debug("Cleanup loop stopped")
|
||||
return
|
||||
case <-ticker.C:
|
||||
g.cleanup()
|
||||
@ -564,6 +607,109 @@ func (g *SIPGuardian) cleanup() {
|
||||
delete(g.failureCounts, ip)
|
||||
}
|
||||
}
|
||||
|
||||
// Log map sizes for monitoring
|
||||
if len(g.bannedIPs) > maxBannedIPs/2 || len(g.failureCounts) > maxTrackedIPs/2 {
|
||||
g.logger.Info("Map size status",
|
||||
zap.Int("banned_ips", len(g.bannedIPs)),
|
||||
zap.Int("tracked_ips", len(g.failureCounts)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// evictOldestTrackers removes the oldest failure trackers to make room (must hold lock)
|
||||
func (g *SIPGuardian) evictOldestTrackers(count int) {
|
||||
// Find and evict the oldest entries by firstSeen
|
||||
type ipTime struct {
|
||||
ip string
|
||||
time time.Time
|
||||
}
|
||||
|
||||
// Collect all entries
|
||||
entries := make([]ipTime, 0, len(g.failureCounts))
|
||||
for ip, tracker := range g.failureCounts {
|
||||
entries = append(entries, ipTime{ip: ip, time: tracker.firstSeen})
|
||||
}
|
||||
|
||||
// Sort by time (oldest first)
|
||||
for i := 0; i < len(entries)-1; i++ {
|
||||
for j := i + 1; j < len(entries); j++ {
|
||||
if entries[j].time.Before(entries[i].time) {
|
||||
entries[i], entries[j] = entries[j], entries[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Evict oldest entries
|
||||
evicted := 0
|
||||
for _, entry := range entries {
|
||||
if evicted >= count {
|
||||
break
|
||||
}
|
||||
delete(g.failureCounts, entry.ip)
|
||||
evicted++
|
||||
}
|
||||
|
||||
if evicted > 0 {
|
||||
g.logger.Info("Evicted oldest failure trackers",
|
||||
zap.Int("evicted", evicted),
|
||||
zap.Int("remaining", len(g.failureCounts)),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// evictOldestBans removes the oldest expired bans to make room (must hold lock)
|
||||
func (g *SIPGuardian) evictOldestBans(count int) {
|
||||
now := time.Now()
|
||||
|
||||
// First, remove expired bans
|
||||
for ip, entry := range g.bannedIPs {
|
||||
if now.After(entry.ExpiresAt) {
|
||||
delete(g.bannedIPs, ip)
|
||||
count--
|
||||
if count <= 0 {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If still need room, remove bans closest to expiry
|
||||
if count > 0 {
|
||||
type ipTime struct {
|
||||
ip string
|
||||
time time.Time
|
||||
}
|
||||
|
||||
entries := make([]ipTime, 0, len(g.bannedIPs))
|
||||
for ip, ban := range g.bannedIPs {
|
||||
entries = append(entries, ipTime{ip: ip, time: ban.ExpiresAt})
|
||||
}
|
||||
|
||||
// Sort by expiry time (soonest first)
|
||||
for i := 0; i < len(entries)-1; i++ {
|
||||
for j := i + 1; j < len(entries); j++ {
|
||||
if entries[j].time.Before(entries[i].time) {
|
||||
entries[i], entries[j] = entries[j], entries[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
evicted := 0
|
||||
for _, entry := range entries {
|
||||
if evicted >= count {
|
||||
break
|
||||
}
|
||||
delete(g.bannedIPs, entry.ip)
|
||||
evicted++
|
||||
}
|
||||
|
||||
if evicted > 0 {
|
||||
g.logger.Warn("Evicted active bans due to capacity limit",
|
||||
zap.Int("evicted", evicted),
|
||||
zap.Int("remaining", len(g.bannedIPs)),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// UnmarshalCaddyfile implements caddyfile.Unmarshaler.
|
||||
@ -742,9 +888,113 @@ func (g *SIPGuardian) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cleanup implements caddy.CleanerUpper.
|
||||
// Called when Caddy config is reloaded to stop goroutines and release resources.
|
||||
func (g *SIPGuardian) Cleanup() error {
|
||||
g.logger.Info("SIP Guardian cleanup starting")
|
||||
|
||||
// Signal all goroutines to stop
|
||||
close(g.stopCh)
|
||||
|
||||
// Stop DNS whitelist background refresh
|
||||
if g.dnsWhitelist != nil {
|
||||
g.dnsWhitelist.Stop()
|
||||
}
|
||||
|
||||
// Wait for goroutines to finish (with timeout)
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
g.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
g.logger.Debug("All goroutines stopped cleanly")
|
||||
case <-time.After(5 * time.Second):
|
||||
g.logger.Warn("Timeout waiting for goroutines to stop")
|
||||
}
|
||||
|
||||
// Close storage connection
|
||||
if g.storage != nil {
|
||||
if err := g.storage.Close(); err != nil {
|
||||
g.logger.Error("Error closing storage", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// Close GeoIP database
|
||||
if g.geoIP != nil {
|
||||
g.geoIP.Close()
|
||||
}
|
||||
|
||||
g.logger.Info("SIP Guardian cleanup complete")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate implements caddy.Validator.
|
||||
// Called after Provision() to validate configuration before use.
|
||||
func (g *SIPGuardian) Validate() error {
|
||||
if g.MaxFailures < 1 {
|
||||
return fmt.Errorf("max_failures must be at least 1, got %d", g.MaxFailures)
|
||||
}
|
||||
if g.MaxFailures > 1000 {
|
||||
return fmt.Errorf("max_failures exceeds reasonable limit (1000), got %d", g.MaxFailures)
|
||||
}
|
||||
|
||||
if time.Duration(g.FindTime) < time.Second {
|
||||
return fmt.Errorf("find_time must be at least 1s, got %v", time.Duration(g.FindTime))
|
||||
}
|
||||
if time.Duration(g.FindTime) > 24*time.Hour {
|
||||
return fmt.Errorf("find_time exceeds reasonable limit (24h), got %v", time.Duration(g.FindTime))
|
||||
}
|
||||
|
||||
if time.Duration(g.BanTime) < time.Second {
|
||||
return fmt.Errorf("ban_time must be at least 1s, got %v", time.Duration(g.BanTime))
|
||||
}
|
||||
|
||||
// Validate conflicting country configurations
|
||||
if len(g.AllowedCountries) > 0 && len(g.BlockedCountries) > 0 {
|
||||
return fmt.Errorf("cannot specify both allowed_countries and blocked_countries")
|
||||
}
|
||||
|
||||
// Validate DNS refresh interval
|
||||
if g.DNSRefresh > 0 && time.Duration(g.DNSRefresh) < 30*time.Second {
|
||||
return fmt.Errorf("dns_refresh must be at least 30s for stability, got %v", time.Duration(g.DNSRefresh))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BanIP manually adds an IP to the ban list with a reason.
|
||||
// This is the public API for external callers (like AdminHandler).
|
||||
func (g *SIPGuardian) BanIP(ip, reason string) {
|
||||
if g.IsWhitelisted(ip) {
|
||||
g.logger.Info("Attempted to ban whitelisted IP", zap.String("ip", ip))
|
||||
return
|
||||
}
|
||||
|
||||
g.mu.Lock()
|
||||
defer g.mu.Unlock()
|
||||
|
||||
// Create a failure tracker if needed (for hit count)
|
||||
if _, exists := g.failureCounts[ip]; !exists {
|
||||
g.failureCounts[ip] = &failureTracker{
|
||||
count: g.MaxFailures,
|
||||
firstSeen: time.Now(),
|
||||
lastSeen: time.Now(),
|
||||
}
|
||||
} else {
|
||||
g.failureCounts[ip].count = g.MaxFailures
|
||||
}
|
||||
|
||||
g.banIP(ip, reason)
|
||||
}
|
||||
|
||||
// Interface guards
|
||||
var (
|
||||
_ caddy.Module = (*SIPGuardian)(nil)
|
||||
_ caddy.Provisioner = (*SIPGuardian)(nil)
|
||||
_ caddy.CleanerUpper = (*SIPGuardian)(nil)
|
||||
_ caddy.Validator = (*SIPGuardian)(nil)
|
||||
_ caddyfile.Unmarshaler = (*SIPGuardian)(nil)
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user