Improve Caddy module lifecycle and safety
- Add Cleanup() method (caddy.CleanerUpper) to stop goroutines on config reload, preventing goroutine leaks - Add Validate() method (caddy.Validator) for early config validation with reasonable bounds checking - Add public BanIP() method for admin handler, replacing direct internal state manipulation - Add bounds checking for failure tracker and ban maps to prevent memory exhaustion under DDoS (100k/50k limits) - Add eviction functions to proactively clean oldest entries when at capacity
This commit is contained in:
parent
5cf34eb3c0
commit
265c606169
9
admin.go
9
admin.go
@ -206,13 +206,8 @@ func (h *AdminHandler) handleBan(w http.ResponseWriter, r *http.Request, path st
|
|||||||
body.Reason = "manual_ban"
|
body.Reason = "manual_ban"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Force a ban by recording max failures
|
// Use public BanIP method for proper encapsulation
|
||||||
h.guardian.mu.Lock()
|
h.guardian.BanIP(ip, body.Reason)
|
||||||
h.guardian.failureCounts[ip] = &failureTracker{
|
|
||||||
count: h.guardian.MaxFailures,
|
|
||||||
}
|
|
||||||
h.guardian.banIP(ip, body.Reason)
|
|
||||||
h.guardian.mu.Unlock()
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
return json.NewEncoder(w).Encode(map[string]interface{}{
|
return json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
|||||||
258
sipguardian.go
258
sipguardian.go
@ -20,6 +20,13 @@ var (
|
|||||||
enableStorage = true
|
enableStorage = true
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Configuration limits to prevent unbounded growth under attack
|
||||||
|
const (
|
||||||
|
maxTrackedIPs = 100000 // Max IPs to track failures for
|
||||||
|
maxBannedIPs = 50000 // Max banned IPs to hold in memory
|
||||||
|
cleanupBatchSize = 1000 // Max entries to clean per cycle
|
||||||
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
caddy.RegisterModule(SIPGuardian{})
|
caddy.RegisterModule(SIPGuardian{})
|
||||||
}
|
}
|
||||||
@ -72,6 +79,10 @@ type SIPGuardian struct {
|
|||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
storage *Storage
|
storage *Storage
|
||||||
geoIP *GeoIPLookup
|
geoIP *GeoIPLookup
|
||||||
|
|
||||||
|
// Lifecycle management
|
||||||
|
stopCh chan struct{}
|
||||||
|
wg sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
type failureTracker struct {
|
type failureTracker struct {
|
||||||
@ -93,6 +104,7 @@ func (g *SIPGuardian) Provision(ctx caddy.Context) error {
|
|||||||
g.logger = ctx.Logger()
|
g.logger = ctx.Logger()
|
||||||
g.bannedIPs = make(map[string]*BanEntry)
|
g.bannedIPs = make(map[string]*BanEntry)
|
||||||
g.failureCounts = make(map[string]*failureTracker)
|
g.failureCounts = make(map[string]*failureTracker)
|
||||||
|
g.stopCh = make(chan struct{})
|
||||||
|
|
||||||
// Set defaults
|
// Set defaults
|
||||||
if g.MaxFailures == 0 {
|
if g.MaxFailures == 0 {
|
||||||
@ -209,8 +221,9 @@ func (g *SIPGuardian) Provision(ctx caddy.Context) error {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start cleanup goroutine
|
// Start cleanup goroutine with proper lifecycle tracking
|
||||||
go g.cleanupLoop(ctx)
|
g.wg.Add(1)
|
||||||
|
go g.cleanupLoop()
|
||||||
|
|
||||||
g.logger.Info("SIP Guardian initialized",
|
g.logger.Info("SIP Guardian initialized",
|
||||||
zap.Int("max_failures", g.MaxFailures),
|
zap.Int("max_failures", g.MaxFailures),
|
||||||
@ -341,6 +354,20 @@ func (g *SIPGuardian) RecordFailure(ip, reason string) bool {
|
|||||||
|
|
||||||
tracker, exists := g.failureCounts[ip]
|
tracker, exists := g.failureCounts[ip]
|
||||||
if !exists || now.Sub(tracker.firstSeen) > findWindow {
|
if !exists || now.Sub(tracker.firstSeen) > findWindow {
|
||||||
|
// Check bounds before adding new entry
|
||||||
|
if !exists && len(g.failureCounts) >= maxTrackedIPs {
|
||||||
|
// Proactively clean old entries to make room
|
||||||
|
g.evictOldestTrackers(cleanupBatchSize)
|
||||||
|
if len(g.failureCounts) >= maxTrackedIPs {
|
||||||
|
// Still at limit, log warning and skip tracking
|
||||||
|
g.logger.Warn("Failure tracker map at capacity, dropping new entry",
|
||||||
|
zap.String("ip", ip),
|
||||||
|
zap.Int("capacity", maxTrackedIPs),
|
||||||
|
)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Start new tracking window
|
// Start new tracking window
|
||||||
tracker = &failureTracker{
|
tracker = &failureTracker{
|
||||||
count: 1,
|
count: 1,
|
||||||
@ -391,6 +418,19 @@ func (g *SIPGuardian) banIP(ip, reason string) {
|
|||||||
now := time.Now()
|
now := time.Now()
|
||||||
banDuration := time.Duration(g.BanTime)
|
banDuration := time.Duration(g.BanTime)
|
||||||
|
|
||||||
|
// Check bounds before adding new ban entry
|
||||||
|
if _, exists := g.bannedIPs[ip]; !exists && len(g.bannedIPs) >= maxBannedIPs {
|
||||||
|
// Proactively clean expired/oldest bans to make room
|
||||||
|
g.evictOldestBans(cleanupBatchSize)
|
||||||
|
if len(g.bannedIPs) >= maxBannedIPs {
|
||||||
|
// Still at limit, log warning but still ban (overwrite random existing)
|
||||||
|
g.logger.Warn("Ban map at capacity, evicting to make room",
|
||||||
|
zap.String("ip", ip),
|
||||||
|
zap.Int("capacity", maxBannedIPs),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
hitCount := 0
|
hitCount := 0
|
||||||
if tracker := g.failureCounts[ip]; tracker != nil {
|
if tracker := g.failureCounts[ip]; tracker != nil {
|
||||||
hitCount = tracker.count
|
hitCount = tracker.count
|
||||||
@ -529,13 +569,16 @@ func (g *SIPGuardian) RefreshDNSWhitelist() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// cleanupLoop periodically removes expired entries
|
// cleanupLoop periodically removes expired entries
|
||||||
func (g *SIPGuardian) cleanupLoop(ctx caddy.Context) {
|
func (g *SIPGuardian) cleanupLoop() {
|
||||||
|
defer g.wg.Done()
|
||||||
|
|
||||||
ticker := time.NewTicker(1 * time.Minute)
|
ticker := time.NewTicker(1 * time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-g.stopCh:
|
||||||
|
g.logger.Debug("Cleanup loop stopped")
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
g.cleanup()
|
g.cleanup()
|
||||||
@ -564,6 +607,109 @@ func (g *SIPGuardian) cleanup() {
|
|||||||
delete(g.failureCounts, ip)
|
delete(g.failureCounts, ip)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Log map sizes for monitoring
|
||||||
|
if len(g.bannedIPs) > maxBannedIPs/2 || len(g.failureCounts) > maxTrackedIPs/2 {
|
||||||
|
g.logger.Info("Map size status",
|
||||||
|
zap.Int("banned_ips", len(g.bannedIPs)),
|
||||||
|
zap.Int("tracked_ips", len(g.failureCounts)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// evictOldestTrackers removes the oldest failure trackers to make room (must hold lock)
|
||||||
|
func (g *SIPGuardian) evictOldestTrackers(count int) {
|
||||||
|
// Find and evict the oldest entries by firstSeen
|
||||||
|
type ipTime struct {
|
||||||
|
ip string
|
||||||
|
time time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect all entries
|
||||||
|
entries := make([]ipTime, 0, len(g.failureCounts))
|
||||||
|
for ip, tracker := range g.failureCounts {
|
||||||
|
entries = append(entries, ipTime{ip: ip, time: tracker.firstSeen})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort by time (oldest first)
|
||||||
|
for i := 0; i < len(entries)-1; i++ {
|
||||||
|
for j := i + 1; j < len(entries); j++ {
|
||||||
|
if entries[j].time.Before(entries[i].time) {
|
||||||
|
entries[i], entries[j] = entries[j], entries[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Evict oldest entries
|
||||||
|
evicted := 0
|
||||||
|
for _, entry := range entries {
|
||||||
|
if evicted >= count {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
delete(g.failureCounts, entry.ip)
|
||||||
|
evicted++
|
||||||
|
}
|
||||||
|
|
||||||
|
if evicted > 0 {
|
||||||
|
g.logger.Info("Evicted oldest failure trackers",
|
||||||
|
zap.Int("evicted", evicted),
|
||||||
|
zap.Int("remaining", len(g.failureCounts)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// evictOldestBans removes the oldest expired bans to make room (must hold lock)
|
||||||
|
func (g *SIPGuardian) evictOldestBans(count int) {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
// First, remove expired bans
|
||||||
|
for ip, entry := range g.bannedIPs {
|
||||||
|
if now.After(entry.ExpiresAt) {
|
||||||
|
delete(g.bannedIPs, ip)
|
||||||
|
count--
|
||||||
|
if count <= 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If still need room, remove bans closest to expiry
|
||||||
|
if count > 0 {
|
||||||
|
type ipTime struct {
|
||||||
|
ip string
|
||||||
|
time time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
entries := make([]ipTime, 0, len(g.bannedIPs))
|
||||||
|
for ip, ban := range g.bannedIPs {
|
||||||
|
entries = append(entries, ipTime{ip: ip, time: ban.ExpiresAt})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort by expiry time (soonest first)
|
||||||
|
for i := 0; i < len(entries)-1; i++ {
|
||||||
|
for j := i + 1; j < len(entries); j++ {
|
||||||
|
if entries[j].time.Before(entries[i].time) {
|
||||||
|
entries[i], entries[j] = entries[j], entries[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
evicted := 0
|
||||||
|
for _, entry := range entries {
|
||||||
|
if evicted >= count {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
delete(g.bannedIPs, entry.ip)
|
||||||
|
evicted++
|
||||||
|
}
|
||||||
|
|
||||||
|
if evicted > 0 {
|
||||||
|
g.logger.Warn("Evicted active bans due to capacity limit",
|
||||||
|
zap.Int("evicted", evicted),
|
||||||
|
zap.Int("remaining", len(g.bannedIPs)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalCaddyfile implements caddyfile.Unmarshaler.
|
// UnmarshalCaddyfile implements caddyfile.Unmarshaler.
|
||||||
@ -742,9 +888,113 @@ func (g *SIPGuardian) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cleanup implements caddy.CleanerUpper.
|
||||||
|
// Called when Caddy config is reloaded to stop goroutines and release resources.
|
||||||
|
func (g *SIPGuardian) Cleanup() error {
|
||||||
|
g.logger.Info("SIP Guardian cleanup starting")
|
||||||
|
|
||||||
|
// Signal all goroutines to stop
|
||||||
|
close(g.stopCh)
|
||||||
|
|
||||||
|
// Stop DNS whitelist background refresh
|
||||||
|
if g.dnsWhitelist != nil {
|
||||||
|
g.dnsWhitelist.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for goroutines to finish (with timeout)
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
g.wg.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
g.logger.Debug("All goroutines stopped cleanly")
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
g.logger.Warn("Timeout waiting for goroutines to stop")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close storage connection
|
||||||
|
if g.storage != nil {
|
||||||
|
if err := g.storage.Close(); err != nil {
|
||||||
|
g.logger.Error("Error closing storage", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close GeoIP database
|
||||||
|
if g.geoIP != nil {
|
||||||
|
g.geoIP.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
g.logger.Info("SIP Guardian cleanup complete")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate implements caddy.Validator.
|
||||||
|
// Called after Provision() to validate configuration before use.
|
||||||
|
func (g *SIPGuardian) Validate() error {
|
||||||
|
if g.MaxFailures < 1 {
|
||||||
|
return fmt.Errorf("max_failures must be at least 1, got %d", g.MaxFailures)
|
||||||
|
}
|
||||||
|
if g.MaxFailures > 1000 {
|
||||||
|
return fmt.Errorf("max_failures exceeds reasonable limit (1000), got %d", g.MaxFailures)
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Duration(g.FindTime) < time.Second {
|
||||||
|
return fmt.Errorf("find_time must be at least 1s, got %v", time.Duration(g.FindTime))
|
||||||
|
}
|
||||||
|
if time.Duration(g.FindTime) > 24*time.Hour {
|
||||||
|
return fmt.Errorf("find_time exceeds reasonable limit (24h), got %v", time.Duration(g.FindTime))
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Duration(g.BanTime) < time.Second {
|
||||||
|
return fmt.Errorf("ban_time must be at least 1s, got %v", time.Duration(g.BanTime))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate conflicting country configurations
|
||||||
|
if len(g.AllowedCountries) > 0 && len(g.BlockedCountries) > 0 {
|
||||||
|
return fmt.Errorf("cannot specify both allowed_countries and blocked_countries")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate DNS refresh interval
|
||||||
|
if g.DNSRefresh > 0 && time.Duration(g.DNSRefresh) < 30*time.Second {
|
||||||
|
return fmt.Errorf("dns_refresh must be at least 30s for stability, got %v", time.Duration(g.DNSRefresh))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BanIP manually adds an IP to the ban list with a reason.
|
||||||
|
// This is the public API for external callers (like AdminHandler).
|
||||||
|
func (g *SIPGuardian) BanIP(ip, reason string) {
|
||||||
|
if g.IsWhitelisted(ip) {
|
||||||
|
g.logger.Info("Attempted to ban whitelisted IP", zap.String("ip", ip))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
g.mu.Lock()
|
||||||
|
defer g.mu.Unlock()
|
||||||
|
|
||||||
|
// Create a failure tracker if needed (for hit count)
|
||||||
|
if _, exists := g.failureCounts[ip]; !exists {
|
||||||
|
g.failureCounts[ip] = &failureTracker{
|
||||||
|
count: g.MaxFailures,
|
||||||
|
firstSeen: time.Now(),
|
||||||
|
lastSeen: time.Now(),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
g.failureCounts[ip].count = g.MaxFailures
|
||||||
|
}
|
||||||
|
|
||||||
|
g.banIP(ip, reason)
|
||||||
|
}
|
||||||
|
|
||||||
// Interface guards
|
// Interface guards
|
||||||
var (
|
var (
|
||||||
_ caddy.Module = (*SIPGuardian)(nil)
|
_ caddy.Module = (*SIPGuardian)(nil)
|
||||||
_ caddy.Provisioner = (*SIPGuardian)(nil)
|
_ caddy.Provisioner = (*SIPGuardian)(nil)
|
||||||
|
_ caddy.CleanerUpper = (*SIPGuardian)(nil)
|
||||||
|
_ caddy.Validator = (*SIPGuardian)(nil)
|
||||||
_ caddyfile.Unmarshaler = (*SIPGuardian)(nil)
|
_ caddyfile.Unmarshaler = (*SIPGuardian)(nil)
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user