caddy-sip-guardian/sipguardian.go
Ryan Malloy 265c606169 Improve Caddy module lifecycle and safety
- Add Cleanup() method (caddy.CleanerUpper) to stop goroutines on config
  reload, preventing goroutine leaks
- Add Validate() method (caddy.Validator) for early config validation with
  reasonable bounds checking
- Add public BanIP() method for admin handler, replacing direct internal
  state manipulation
- Add bounds checking for failure tracker and ban maps to prevent memory
  exhaustion under DDoS (100k/50k limits)
- Add eviction functions to proactively clean oldest entries when at capacity
2025-12-08 01:29:16 -07:00

1001 lines
25 KiB
Go

// Package sipguardian provides a Caddy module for SIP-aware rate limiting and IP banning.
// It integrates with caddy-l4 for Layer 4 proxying and caddy-ratelimit for rate limiting.
package sipguardian
import (
	"fmt"
	"net"
	"sort"
	"strconv"
	"sync"
	"time"

	"github.com/caddyserver/caddy/v2"
	"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
	"go.uber.org/zap"
)
// Feature flags that gate the optional subsystems. Each defaults to on;
// when a flag is false the corresponding metrics, webhook, or persistent
// storage calls are skipped.
var (
	enableStorage  = true
	enableWebhooks = true
	enableMetrics  = true
)
// Capacity limits that keep the in-memory maps bounded while under attack.
const (
	// maxTrackedIPs caps how many distinct IPs may hold a failure tracker.
	maxTrackedIPs = 100_000
	// maxBannedIPs caps how many ban entries are held in memory.
	maxBannedIPs = 50_000
	// cleanupBatchSize is how many entries one eviction pass tries to free
	// when either map above is full.
	cleanupBatchSize = 1_000
)
func init() {
caddy.RegisterModule(SIPGuardian{})
}
// BanEntry represents a banned IP with metadata.
// Entries live by pointer in SIPGuardian.bannedIPs (keyed by IP), are copied
// out by GetBannedIPs, and are handed to storage and webhooks when created.
type BanEntry struct {
// IP is the banned address; it matches the bannedIPs map key.
IP string `json:"ip"`
// Reason describes why the ban was applied (failure reason or caller-supplied text).
Reason string `json:"reason"`
// BannedAt is when the ban was created.
BannedAt time.Time `json:"banned_at"`
// ExpiresAt ends the ban: it counts as active only while the current time is before it.
ExpiresAt time.Time `json:"expires_at"`
// HitCount is the failure count tracked for the IP at ban time (0 if none was tracked).
HitCount int `json:"hit_count"`
}
// SIPGuardian implements intelligent SIP protection at Layer 4.
// It counts failures per client IP, bans an IP for BanTime once it reaches
// MaxFailures within FindTime, and never bans whitelisted sources. Optional
// subsystems: persistent storage, GeoIP country filtering, webhooks,
// DNS-based whitelisting, enumeration detection, and SIP validation.
type SIPGuardian struct {
// Configuration
// MaxFailures is the failure count that triggers a ban (default 5).
MaxFailures int `json:"max_failures,omitempty"`
// FindTime is the window, measured from an IP's first failure, in which
// failures are counted; afterwards the count restarts (default 10m).
FindTime caddy.Duration `json:"find_time,omitempty"`
// BanTime is how long a ban lasts (default 1h).
BanTime caddy.Duration `json:"ban_time,omitempty"`
// WhitelistCIDR lists static whitelist networks; Provision parses them into whitelistNets.
WhitelistCIDR []string `json:"whitelist_cidr,omitempty"`
// DNS-aware whitelist configuration
WhitelistHosts []string `json:"whitelist_hosts,omitempty"` // Hostnames to resolve (A/AAAA)
WhitelistSRV []string `json:"whitelist_srv,omitempty"` // SRV records to resolve
DNSRefresh caddy.Duration `json:"dns_refresh,omitempty"` // DNS refresh interval (default: 5m)
// Webhook configuration
// Webhooks are registered with the webhook manager during Provision.
Webhooks []WebhookConfig `json:"webhooks,omitempty"`
// Storage configuration
// StoragePath enables persistent ban storage when non-empty.
StoragePath string `json:"storage_path,omitempty"`
// GeoIP configuration
// GeoIPPath is the GeoIP database path; country filtering is off when empty.
GeoIPPath string `json:"geoip_path,omitempty"`
// BlockedCountries is ignored when AllowedCountries is non-empty
// (Validate rejects setting both).
BlockedCountries []string `json:"blocked_countries,omitempty"`
// AllowedCountries, when non-empty, blocks every country not listed.
AllowedCountries []string `json:"allowed_countries,omitempty"`
// Enumeration detection configuration
Enumeration *EnumerationConfig `json:"enumeration,omitempty"`
// Validation configuration
Validation *ValidationConfig `json:"validation,omitempty"`
// Runtime state
logger *zap.Logger
// bannedIPs and failureCounts are keyed by IP string and guarded by mu.
bannedIPs map[string]*BanEntry
failureCounts map[string]*failureTracker
// whitelistNets is built in Provision and only read afterwards.
whitelistNets []*net.IPNet
// dnsWhitelist is nil unless WhitelistHosts or WhitelistSRV is set.
dnsWhitelist *DNSWhitelist
mu sync.RWMutex
// storage and geoIP are nil when disabled or when initialization failed.
storage *Storage
geoIP *GeoIPLookup
// Lifecycle management
// stopCh is closed by Cleanup to stop cleanupLoop; wg tracks that goroutine.
stopCh chan struct{}
wg sync.WaitGroup
}
// failureTracker counts failures for one IP within its current find window.
// In this file it is only read or written while SIPGuardian.mu is held.
type failureTracker struct {
count int // failures seen in the current window
firstSeen time.Time // window start; the window restarts once now-firstSeen exceeds FindTime
lastSeen time.Time // time of the most recent failure
}
// CaddyModule returns the Caddy module information: the module ID
// "sip_guardian" and a constructor that yields a fresh, unprovisioned
// *SIPGuardian.
func (SIPGuardian) CaddyModule() caddy.ModuleInfo {
	newModule := func() caddy.Module {
		return &SIPGuardian{}
	}
	return caddy.ModuleInfo{
		ID:  "sip_guardian",
		New: newModule,
	}
}
// Provision sets up the module.
//
// It allocates runtime state and applies defaults (max_failures 5,
// find_time 10m, ban_time 1h). It then parses the static whitelist and
// initializes the optional subsystems: metrics, webhooks, persistent
// storage, GeoIP, the DNS whitelist, enumeration detection, and SIP
// validation. Finally it starts the background cleanup loop, which Cleanup
// stops.
//
// Only an invalid whitelist entry makes Provision fail. Storage, GeoIP, and
// DNS-whitelist initialization errors are logged, and the feature is left
// disabled. The DNS whitelist object is kept even if its Start call fails.
//
// NOTE(review): RegisterMetrics, GetWebhookManager, SetEnumerationConfig,
// and SetValidationConfig look like package-level state. On config reload
// a new instance is provisioned while the old one may still exist, so
// webhooks could be added to the manager again. Confirm against their
// implementations.
func (g *SIPGuardian) Provision(ctx caddy.Context) error {
	g.logger = ctx.Logger()
	g.bannedIPs = make(map[string]*BanEntry)
	g.failureCounts = make(map[string]*failureTracker)
	g.stopCh = make(chan struct{})

	// Set defaults
	if g.MaxFailures == 0 {
		g.MaxFailures = 5
	}
	if g.FindTime == 0 {
		g.FindTime = caddy.Duration(10 * time.Minute)
	}
	if g.BanTime == 0 {
		g.BanTime = caddy.Duration(1 * time.Hour)
	}

	// Parse the static whitelist. Entries may be CIDRs ("10.0.0.0/8") or
	// bare addresses ("203.0.113.7"). net.ParseCIDR alone would reject bare
	// addresses.
	for _, entry := range g.WhitelistCIDR {
		network, err := parseWhitelistCIDROrIP(entry)
		if err != nil {
			return fmt.Errorf("invalid whitelist CIDR %s: %w", entry, err)
		}
		g.whitelistNets = append(g.whitelistNets, network)
	}

	// Initialize metrics
	if enableMetrics {
		RegisterMetrics()
	}

	// Initialize webhooks
	if enableWebhooks && len(g.Webhooks) > 0 {
		wm := GetWebhookManager(g.logger)
		for _, wh := range g.Webhooks {
			wm.AddWebhook(wh)
		}
	}

	// Initialize persistent storage; failure is non-fatal.
	if enableStorage && g.StoragePath != "" {
		storage, err := InitStorage(g.logger, StorageConfig{
			Path: g.StoragePath,
		})
		if err != nil {
			g.logger.Warn("Failed to initialize storage, continuing without persistence",
				zap.Error(err),
			)
		} else {
			g.storage = storage
			// Restore bans persisted by a previous run.
			if err := g.loadBansFromStorage(); err != nil {
				g.logger.Warn("Failed to load bans from storage", zap.Error(err))
			}
		}
	}

	// Initialize GeoIP if configured; failure disables country blocking.
	if g.GeoIPPath != "" {
		geoIP, err := NewGeoIPLookup(g.GeoIPPath)
		if err != nil {
			g.logger.Warn("Failed to initialize GeoIP, country blocking disabled",
				zap.Error(err),
			)
		} else {
			g.geoIP = geoIP
			g.logger.Info("GeoIP initialized",
				zap.Int("blocked_countries", len(g.BlockedCountries)),
				zap.Int("allowed_countries", len(g.AllowedCountries)),
			)
		}
	}

	// Initialize DNS whitelist if configured
	if len(g.WhitelistHosts) > 0 || len(g.WhitelistSRV) > 0 {
		refreshInterval := 5 * time.Minute
		if g.DNSRefresh > 0 {
			refreshInterval = time.Duration(g.DNSRefresh)
		}
		g.dnsWhitelist = NewDNSWhitelist(DNSWhitelistConfig{
			Hostnames:       g.WhitelistHosts,
			SRVRecords:      g.WhitelistSRV,
			RefreshInterval: refreshInterval,
			AllowStale:      true,
			ResolveTimeout:  10 * time.Second,
		}, g.logger)
		if err := g.dnsWhitelist.Start(); err != nil {
			g.logger.Warn("Failed to initialize DNS whitelist",
				zap.Error(err),
			)
		} else {
			g.logger.Info("DNS whitelist initialized",
				zap.Int("hostnames", len(g.WhitelistHosts)),
				zap.Int("srv_records", len(g.WhitelistSRV)),
				zap.Duration("refresh_interval", refreshInterval),
			)
		}
	}

	// Initialize enumeration detection with config if specified
	if g.Enumeration != nil {
		SetEnumerationConfig(*g.Enumeration)
		g.logger.Info("Enumeration detection configured",
			zap.Int("max_extensions", g.Enumeration.MaxExtensions),
			zap.Int("sequential_threshold", g.Enumeration.SequentialThreshold),
			zap.Duration("extension_window", g.Enumeration.ExtensionWindow),
		)
	}

	// Initialize validation with config if specified
	if g.Validation != nil {
		SetValidationConfig(*g.Validation)
		g.logger.Info("SIP validation configured",
			zap.String("mode", string(g.Validation.Mode)),
			zap.Bool("enabled", g.Validation.Enabled),
			zap.Int("max_message_size", g.Validation.MaxMessageSize),
		)
	}

	// Start cleanup goroutine; Cleanup closes stopCh and waits on wg.
	g.wg.Add(1)
	go g.cleanupLoop()

	g.logger.Info("SIP Guardian initialized",
		zap.Int("max_failures", g.MaxFailures),
		zap.Duration("find_time", time.Duration(g.FindTime)),
		zap.Duration("ban_time", time.Duration(g.BanTime)),
		zap.Int("whitelist_count", len(g.whitelistNets)),
		zap.Bool("storage_enabled", g.storage != nil),
		zap.Bool("geoip_enabled", g.geoIP != nil),
		zap.Int("webhook_count", len(g.Webhooks)),
		zap.Bool("enumeration_enabled", g.Enumeration != nil),
		zap.Bool("validation_enabled", g.Validation != nil && g.Validation.Enabled),
	)
	return nil
}

// parseWhitelistCIDROrIP parses one static whitelist entry. A CIDR such as
// "10.0.0.0/8" is returned as parsed. A bare address becomes a single-host
// network: /32 for IPv4, /128 for IPv6.
func parseWhitelistCIDROrIP(entry string) (*net.IPNet, error) {
	if ip := net.ParseIP(entry); ip != nil {
		if v4 := ip.To4(); v4 != nil {
			return &net.IPNet{IP: v4, Mask: net.CIDRMask(32, 32)}, nil
		}
		return &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)}, nil
	}
	_, network, err := net.ParseCIDR(entry)
	if err != nil {
		return nil, err
	}
	return network, nil
}
// loadBansFromStorage restores bans saved by a previous run into bannedIPs.
// It is a no-op when persistent storage is disabled.
//
// Entries that have already expired are skipped. Once maxBannedIPs is
// reached, no new IPs are added, so a large backlog in storage cannot bypass
// the in-memory cap that banIP enforces. The logged count is the number of
// bans actually inserted.
func (g *SIPGuardian) loadBansFromStorage() error {
	if g.storage == nil {
		return nil
	}
	bans, err := g.storage.LoadActiveBans()
	if err != nil {
		return err
	}

	g.mu.Lock()
	defer g.mu.Unlock()

	now := time.Now()
	loaded, expired, dropped := 0, 0, 0
	for _, ban := range bans {
		entry := ban // per-iteration copy so each map value has its own address
		// IsBanned treats these as inactive anyway; don't spend capacity on them.
		if !now.Before(entry.ExpiresAt) {
			expired++
			continue
		}
		if _, exists := g.bannedIPs[entry.IP]; !exists && len(g.bannedIPs) >= maxBannedIPs {
			dropped++
			continue
		}
		g.bannedIPs[entry.IP] = &entry
		loaded++
	}

	g.logger.Info("Loaded bans from storage",
		zap.Int("count", loaded),
		zap.Int("skipped_expired", expired),
	)
	if dropped > 0 {
		g.logger.Warn("Ban map at capacity while loading from storage, dropping remaining bans",
			zap.Int("dropped", dropped),
			zap.Int("capacity", maxBannedIPs),
		)
	}
	return nil
}
// IsWhitelisted reports whether ip is exempt from banning. An IP is exempt if
// it falls inside a static whitelist network or matches an address resolved
// by the DNS-based whitelist. Input that does not parse as an IP is never
// whitelisted. Each match is counted in metrics when metrics are enabled.
func (g *SIPGuardian) IsWhitelisted(ip string) bool {
	addr := net.ParseIP(ip)
	if addr == nil {
		return false
	}

	inStatic := false
	for _, n := range g.whitelistNets {
		if n.Contains(addr) {
			inStatic = true
			break
		}
	}
	if inStatic {
		if enableMetrics {
			RecordWhitelistedConnection()
		}
		return true
	}

	// Fall back to the dynamically resolved DNS whitelist.
	if g.dnsWhitelist == nil || !g.dnsWhitelist.Contains(ip) {
		return false
	}
	if enableMetrics {
		RecordWhitelistedConnection()
	}
	g.logger.Debug("IP whitelisted via DNS",
		zap.String("ip", ip),
		zap.String("source", g.dnsWhitelist.GetSource(ip).Source),
	)
	return true
}
// IsCountryBlocked reports whether ip should be rejected based on its GeoIP
// country, along with the country code that was looked up.
//
// When AllowedCountries is non-empty it acts as an allow-list, and any other
// country is blocked. Otherwise only countries in BlockedCountries are
// blocked. Codes are compared with exact string equality. If GeoIP is
// disabled or the lookup fails, the IP is not blocked and the code is "".
func (g *SIPGuardian) IsCountryBlocked(ip string) (bool, string) {
	if g.geoIP == nil {
		return false, ""
	}

	country, err := g.geoIP.LookupCountry(ip)
	if err != nil {
		g.logger.Debug("GeoIP lookup failed", zap.String("ip", ip), zap.Error(err))
		return false, ""
	}

	listed := func(codes []string) bool {
		for _, code := range codes {
			if code == country {
				return true
			}
		}
		return false
	}

	if len(g.AllowedCountries) > 0 {
		return !listed(g.AllowedCountries), country
	}
	return listed(g.BlockedCountries), country
}
// IsBanned reports whether ip has a ban whose expiry is still in the future.
// Expired entries that cleanup has not yet removed do not count.
func (g *SIPGuardian) IsBanned(ip string) bool {
	g.mu.RLock()
	defer g.mu.RUnlock()

	entry, found := g.bannedIPs[ip]
	return found && time.Now().Before(entry.ExpiresAt)
}
// RecordFailure records a failed authentication attempt from ip and reports
// whether this failure caused the IP to be banned.
//
// Whitelisted IPs are ignored. A failure either extends the IP's current
// find window or, if that window has elapsed or none exists, starts a new
// window at count 1. If the tracker map is full and a batch eviction cannot
// free room, the failure is dropped and false is returned. Reaching
// MaxFailures bans the IP via banIP.
func (g *SIPGuardian) RecordFailure(ip, reason string) bool {
	if g.IsWhitelisted(ip) {
		return false
	}

	g.mu.Lock()
	defer g.mu.Unlock()

	now := time.Now()
	tracker, tracked := g.failureCounts[ip]
	windowExpired := tracked && now.Sub(tracker.firstSeen) > time.Duration(g.FindTime)

	if tracked && !windowExpired {
		tracker.count++
		tracker.lastSeen = now
	} else {
		// Only a brand-new IP grows the map; replacing a stale tracker does not.
		if !tracked && len(g.failureCounts) >= maxTrackedIPs {
			g.evictOldestTrackers(cleanupBatchSize)
			if len(g.failureCounts) >= maxTrackedIPs {
				g.logger.Warn("Failure tracker map at capacity, dropping new entry",
					zap.String("ip", ip),
					zap.Int("capacity", maxTrackedIPs),
				)
				return false
			}
		}
		tracker = &failureTracker{count: 1, firstSeen: now, lastSeen: now}
		g.failureCounts[ip] = tracker
	}

	g.logger.Debug("Failure recorded",
		zap.String("ip", ip),
		zap.String("reason", reason),
		zap.Int("count", tracker.count),
	)

	if enableMetrics {
		RecordFailure(reason)
		UpdateTrackedIPs(len(g.failureCounts))
	}

	// Persist asynchronously so storage latency never holds the lock.
	if g.storage != nil {
		go g.storage.RecordFailure(ip, reason, nil)
	}

	if enableWebhooks {
		EmitFailureEvent(g.logger, ip, reason, tracker.count)
	}

	if tracker.count < g.MaxFailures {
		return false
	}
	g.banIP(ip, reason)
	return true
}
// banIP places ip on the ban list for BanTime, recording reason and the
// failure count tracked so far. The caller must already hold g.mu for
// writing; RecordFailure and BanIP both do.
//
// If the ban map is full and ip is not already present, evictOldestBans is
// asked to free a batch first. The warning fires only if that still left the
// map full, and the new entry is inserted either way. Side effects:
//   - the IP's failure tracker is dropped
//   - the ban is counted in metrics
//   - it is persisted asynchronously when storage is enabled
//   - it is announced via webhook
func (g *SIPGuardian) banIP(ip, reason string) {
	now := time.Now()
	duration := time.Duration(g.BanTime)

	_, present := g.bannedIPs[ip]
	if !present && len(g.bannedIPs) >= maxBannedIPs {
		g.evictOldestBans(cleanupBatchSize)
		if len(g.bannedIPs) >= maxBannedIPs {
			g.logger.Warn("Ban map at capacity, evicting to make room",
				zap.String("ip", ip),
				zap.Int("capacity", maxBannedIPs),
			)
		}
	}

	var hits int
	if tracker := g.failureCounts[ip]; tracker != nil {
		hits = tracker.count
	}

	ban := &BanEntry{
		IP:        ip,
		Reason:    reason,
		BannedAt:  now,
		ExpiresAt: now.Add(duration),
		HitCount:  hits,
	}
	g.bannedIPs[ip] = ban
	delete(g.failureCounts, ip)

	g.logger.Warn("IP banned",
		zap.String("ip", ip),
		zap.String("reason", reason),
		zap.Duration("duration", duration),
	)

	if enableMetrics {
		RecordBan()
	}

	if g.storage != nil {
		go func() {
			if err := g.storage.SaveBan(ban); err != nil {
				g.logger.Error("Failed to save ban to storage", zap.Error(err))
			}
		}()
	}

	if enableWebhooks {
		EmitBanEvent(g.logger, ban)
	}
}
// UnbanIP removes ip from the ban list and reports whether an entry existed.
// It records ban-duration and unban metrics, removes the ban from storage
// asynchronously, and emits an unban webhook event. Entries that have
// expired but are still in the map are also removed, and true is returned.
func (g *SIPGuardian) UnbanIP(ip string) bool {
	const unbanReason = "manual_unban"

	g.mu.Lock()
	defer g.mu.Unlock()

	entry, found := g.bannedIPs[ip]
	if !found {
		return false
	}

	if enableMetrics {
		RecordBanDuration(time.Since(entry.BannedAt).Seconds())
		RecordUnban()
	}

	delete(g.bannedIPs, ip)
	g.logger.Info("IP unbanned", zap.String("ip", ip))

	if g.storage != nil {
		go func() {
			if err := g.storage.RemoveBan(ip, unbanReason); err != nil {
				g.logger.Error("Failed to update storage on unban", zap.Error(err))
			}
		}()
	}

	if enableWebhooks {
		EmitUnbanEvent(g.logger, ip, unbanReason)
	}
	return true
}
// GetBannedIPs returns a snapshot of all bans that have not yet expired.
//
// The result is never nil; it is an empty slice when nothing is banned, so
// encoding/json emits [] rather than null. Entries are copies, so callers
// may modify them freely. Order is unspecified because it follows map
// iteration.
func (g *SIPGuardian) GetBannedIPs() []BanEntry {
	g.mu.RLock()
	defer g.mu.RUnlock()

	now := time.Now()
	entries := make([]BanEntry, 0, len(g.bannedIPs))
	for _, entry := range g.bannedIPs {
		if now.Before(entry.ExpiresAt) {
			entries = append(entries, *entry)
		}
	}
	return entries
}
// GetStats returns a snapshot of guardian counters:
//   - active (unexpired) bans
//   - tracked failure windows
//   - static and DNS whitelist configuration sizes
//   - the DNS whitelist's own stats, when that feature is enabled
func (g *SIPGuardian) GetStats() map[string]interface{} {
	g.mu.RLock()
	defer g.mu.RUnlock()

	now := time.Now()
	active := 0
	for _, ban := range g.bannedIPs {
		if !now.Before(ban.ExpiresAt) {
			continue
		}
		active++
	}

	stats := make(map[string]interface{}, 6)
	stats["active_bans"] = active
	stats["tracked_failures"] = len(g.failureCounts)
	stats["whitelist_cidr"] = len(g.whitelistNets)
	stats["whitelist_hosts"] = len(g.WhitelistHosts)
	stats["whitelist_srv"] = len(g.WhitelistSRV)

	if dw := g.dnsWhitelist; dw != nil {
		stats["dns_whitelist"] = dw.Stats()
	}
	return stats
}
// GetDNSWhitelistEntries returns every address currently resolved by the DNS
// whitelist, or nil when DNS-based whitelisting is not configured.
func (g *SIPGuardian) GetDNSWhitelistEntries() []ResolvedEntry {
	if dw := g.dnsWhitelist; dw != nil {
		return dw.GetResolvedIPs()
	}
	return nil
}
// RefreshDNSWhitelist triggers an immediate re-resolution of the DNS
// whitelist. It does nothing when DNS-based whitelisting is not configured.
func (g *SIPGuardian) RefreshDNSWhitelist() {
	if g.dnsWhitelist == nil {
		return
	}
	g.dnsWhitelist.ForceRefresh()
}
// cleanupLoop runs cleanup once per minute until stopCh is closed. It is
// started by Provision with a matching wg.Add(1) and signals wg on exit.
func (g *SIPGuardian) cleanupLoop() {
	defer g.wg.Done()

	const interval = time.Minute
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			g.cleanup()
		case <-g.stopCh:
			g.logger.Debug("Cleanup loop stopped")
			return
		}
	}
}
// cleanup performs one maintenance pass under g.mu:
//   - removes bans whose expiry has passed
//   - removes failure trackers whose find window has elapsed
//   - refreshes the tracked-IP gauge
//   - logs map sizes once either map is more than half full
//
// Expired bans are removed from memory only. Storage, metrics, and webhooks
// are not notified here.
func (g *SIPGuardian) cleanup() {
	g.mu.Lock()
	defer g.mu.Unlock()

	now := time.Now()
	findWindow := time.Duration(g.FindTime)

	// Cleanup expired bans
	for ip, entry := range g.bannedIPs {
		if now.After(entry.ExpiresAt) {
			delete(g.bannedIPs, ip)
			g.logger.Debug("Ban expired", zap.String("ip", ip))
		}
	}

	// Cleanup old failure trackers
	for ip, tracker := range g.failureCounts {
		if now.Sub(tracker.firstSeen) > findWindow {
			delete(g.failureCounts, ip)
		}
	}

	// In this file the gauge is otherwise updated only when a failure arrives,
	// so without this refresh it would keep reporting pruned trackers.
	if enableMetrics {
		UpdateTrackedIPs(len(g.failureCounts))
	}

	// Log map sizes for monitoring
	if len(g.bannedIPs) > maxBannedIPs/2 || len(g.failureCounts) > maxTrackedIPs/2 {
		g.logger.Info("Map size status",
			zap.Int("banned_ips", len(g.bannedIPs)),
			zap.Int("tracked_ips", len(g.failureCounts)),
		)
	}
}
// evictOldestTrackers removes up to count failure trackers to make room for
// new entries. Trackers with the earliest window start (firstSeen) go first.
// The caller must hold g.mu for writing.
//
// Ordering uses sort.Slice, which is O(n log n). A quadratic sort over up to
// maxTrackedIPs entries would mean on the order of 10^10 comparisons while
// holding the lock. An attacker could trigger that stall simply by spraying
// new source IPs.
func (g *SIPGuardian) evictOldestTrackers(count int) {
	if count <= 0 || len(g.failureCounts) == 0 {
		return
	}

	type candidate struct {
		ip        string
		firstSeen time.Time
	}
	candidates := make([]candidate, 0, len(g.failureCounts))
	for ip, tracker := range g.failureCounts {
		candidates = append(candidates, candidate{ip: ip, firstSeen: tracker.firstSeen})
	}

	// Oldest first.
	sort.Slice(candidates, func(i, j int) bool {
		return candidates[i].firstSeen.Before(candidates[j].firstSeen)
	})

	if count > len(candidates) {
		count = len(candidates)
	}
	for _, c := range candidates[:count] {
		delete(g.failureCounts, c.ip)
	}

	g.logger.Info("Evicted oldest failure trackers",
		zap.Int("evicted", count),
		zap.Int("remaining", len(g.failureCounts)),
	)
}
// evictOldestBans frees up to count slots in g.bannedIPs to make room for a
// new ban. The caller must hold g.mu for writing.
//
// Expired bans are removed first, since they no longer block anything. If
// that frees fewer than count slots, active bans are evicted in order of
// soonest expiry, dropping those with the least ban time remaining. Evicted
// bans are removed from memory only; storage is not updated.
func (g *SIPGuardian) evictOldestBans(count int) {
	if count <= 0 {
		return
	}

	now := time.Now()
	for ip, entry := range g.bannedIPs {
		if now.After(entry.ExpiresAt) {
			delete(g.bannedIPs, ip)
			count--
			if count == 0 {
				return
			}
		}
	}

	type candidate struct {
		ip        string
		expiresAt time.Time
	}
	candidates := make([]candidate, 0, len(g.bannedIPs))
	for ip, ban := range g.bannedIPs {
		candidates = append(candidates, candidate{ip: ip, expiresAt: ban.ExpiresAt})
	}

	// sort.Slice keeps this O(n log n). A quadratic sort over up to
	// maxBannedIPs entries would stall every caller waiting on g.mu.
	sort.Slice(candidates, func(i, j int) bool {
		return candidates[i].expiresAt.Before(candidates[j].expiresAt)
	})

	if count > len(candidates) {
		count = len(candidates)
	}
	for _, c := range candidates[:count] {
		delete(g.bannedIPs, c.ip)
	}

	if count > 0 {
		g.logger.Warn("Evicted active bans due to capacity limit",
			zap.Int("evicted", count),
			zap.Int("remaining", len(g.bannedIPs)),
		)
	}
}
// UnmarshalCaddyfile implements caddyfile.Unmarshaler.
// Extended configuration options:
//
//	sip_guardian {
//	    max_failures 5
//	    find_time 10m
//	    ban_time 1h
//
//	    # IP/CIDR whitelist (static)
//	    whitelist 10.0.0.0/8 192.168.0.0/16
//
//	    # DNS-aware whitelist (dynamic, auto-refreshed)
//	    whitelist_hosts pbx.example.com trunk.provider.com
//	    whitelist_srv _sip._udp.provider.com _sip._tcp.carrier.net
//	    dns_refresh 5m  # How often to refresh DNS (default: 5m)
//
//	    # Persistent storage
//	    storage /var/lib/sip-guardian/guardian.db
//
//	    # GeoIP blocking (requires MaxMind database)
//	    geoip_db /path/to/GeoLite2-Country.mmdb
//	    block_countries CN RU KP
//	    allow_countries US CA GB  # Alternative: only allow these
//
//	    # Webhook notifications
//	    webhook https://example.com/hook {
//	        events ban unban suspicious
//	        secret my-webhook-secret
//	        timeout 10s
//	    }
//	}
//
// max_failures must be a plain decimal integer. Trailing characters such as
// "5x" are rejected rather than silently truncated.
func (g *SIPGuardian) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
	// durationArg consumes the next argument and parses it as a Caddy
	// duration. name is used in the error message.
	durationArg := func(name string) (time.Duration, error) {
		if !d.NextArg() {
			return 0, d.ArgErr()
		}
		dur, err := caddy.ParseDuration(d.Val())
		if err != nil {
			return 0, d.Errf("invalid %s: %v", name, err)
		}
		return dur, nil
	}

	// appendCountries consumes the remaining arguments on the line as
	// country codes. Continent codes (e.g. "AS" for all of Asia) are
	// expanded via ExpandContinentCode.
	appendCountries := func(dst []string) []string {
		for d.NextArg() {
			code := d.Val()
			if expanded := ExpandContinentCode(code); expanded != nil {
				dst = append(dst, expanded...)
			} else {
				dst = append(dst, code)
			}
		}
		return dst
	}

	for d.Next() {
		for d.NextBlock(0) {
			switch d.Val() {
			case "max_failures":
				if !d.NextArg() {
					return d.ArgErr()
				}
				// strconv.Atoi rejects "5x"; fmt.Sscanf("%d") would accept it as 5.
				val, err := strconv.Atoi(d.Val())
				if err != nil {
					return d.Errf("invalid max_failures: %v", err)
				}
				g.MaxFailures = val
			case "find_time":
				dur, err := durationArg("find_time")
				if err != nil {
					return err
				}
				g.FindTime = caddy.Duration(dur)
			case "ban_time":
				dur, err := durationArg("ban_time")
				if err != nil {
					return err
				}
				g.BanTime = caddy.Duration(dur)
			case "whitelist":
				// Legacy: static IP/CIDR whitelist
				g.WhitelistCIDR = append(g.WhitelistCIDR, d.RemainingArgs()...)
			case "whitelist_hosts":
				// DNS A/AAAA record whitelist (hostnames resolved to IPs)
				g.WhitelistHosts = append(g.WhitelistHosts, d.RemainingArgs()...)
			case "whitelist_srv":
				// DNS SRV record whitelist (e.g., _sip._udp.provider.com)
				g.WhitelistSRV = append(g.WhitelistSRV, d.RemainingArgs()...)
			case "dns_refresh":
				// Interval for refreshing DNS-based whitelist entries
				dur, err := durationArg("dns_refresh")
				if err != nil {
					return err
				}
				g.DNSRefresh = caddy.Duration(dur)
			case "storage":
				if !d.NextArg() {
					return d.ArgErr()
				}
				g.StoragePath = d.Val()
			case "geoip_db":
				if !d.NextArg() {
					return d.ArgErr()
				}
				g.GeoIPPath = d.Val()
			case "block_countries":
				g.BlockedCountries = appendCountries(g.BlockedCountries)
			case "allow_countries":
				g.AllowedCountries = appendCountries(g.AllowedCountries)
			case "webhook":
				if !d.NextArg() {
					return d.ArgErr()
				}
				webhook := WebhookConfig{
					URL: d.Val(),
				}
				// Parse webhook block if present
				for nesting := d.Nesting(); d.NextBlock(nesting); {
					switch d.Val() {
					case "events":
						webhook.Events = d.RemainingArgs()
					case "secret":
						if !d.NextArg() {
							return d.ArgErr()
						}
						webhook.Secret = d.Val()
					case "timeout":
						dur, err := durationArg("webhook timeout")
						if err != nil {
							return err
						}
						webhook.Timeout = dur
					case "header":
						args := d.RemainingArgs()
						if len(args) != 2 {
							return d.Errf("header requires name and value")
						}
						if webhook.Headers == nil {
							webhook.Headers = make(map[string]string)
						}
						webhook.Headers[args[0]] = args[1]
					default:
						return d.Errf("unknown webhook directive: %s", d.Val())
					}
				}
				g.Webhooks = append(g.Webhooks, webhook)
			default:
				return d.Errf("unknown directive: %s", d.Val())
			}
		}
	}
	return nil
}
// Cleanup implements caddy.CleanerUpper.
// Called when the Caddy config is reloaded, and, per the CleanerUpper
// contract, possibly after a failed Provision. It stops background
// goroutines and releases storage and GeoIP resources.
//
// A repeated call returns immediately instead of panicking on a second close
// of stopCh. A module whose Provision never ran (nil stopCh) has nothing to
// release. Concurrent calls to Cleanup are not synchronized.
//
// NOTE(review): storage writes started by RecordFailure, banIP, and UnbanIP
// run in goroutines not tracked by wg. One may still be in flight when
// storage is closed below; confirm the Storage implementation tolerates
// that. Also confirm that DNSWhitelist.Stop is safe when Start had failed.
func (g *SIPGuardian) Cleanup() error {
	if g.stopCh == nil {
		return nil
	}
	select {
	case <-g.stopCh:
		// Already cleaned up.
		return nil
	default:
	}

	g.logger.Info("SIP Guardian cleanup starting")

	// Signal all goroutines to stop
	close(g.stopCh)

	// Stop DNS whitelist background refresh
	if g.dnsWhitelist != nil {
		g.dnsWhitelist.Stop()
	}

	// Wait for tracked goroutines to finish, but never block reload forever.
	done := make(chan struct{})
	go func() {
		g.wg.Wait()
		close(done)
	}()
	select {
	case <-done:
		g.logger.Debug("All goroutines stopped cleanly")
	case <-time.After(5 * time.Second):
		g.logger.Warn("Timeout waiting for goroutines to stop")
	}

	// Close storage connection
	if g.storage != nil {
		if err := g.storage.Close(); err != nil {
			g.logger.Error("Error closing storage", zap.Error(err))
		}
	}

	// Close GeoIP database
	if g.geoIP != nil {
		g.geoIP.Close()
	}

	g.logger.Info("SIP Guardian cleanup complete")
	return nil
}
// Validate implements caddy.Validator.
// Caddy calls it after Provision, so defaults are already applied. It
// enforces these rules:
//   - max_failures in [1, 1000]
//   - find_time in [1s, 24h]
//   - ban_time of at least 1s
//   - mutually exclusive allow and block country lists
//   - a dns_refresh of at least 30s when set
func (g *SIPGuardian) Validate() error {
	findTime := time.Duration(g.FindTime)
	banTime := time.Duration(g.BanTime)
	dnsRefresh := time.Duration(g.DNSRefresh)

	switch {
	case g.MaxFailures < 1:
		return fmt.Errorf("max_failures must be at least 1, got %d", g.MaxFailures)
	case g.MaxFailures > 1000:
		return fmt.Errorf("max_failures exceeds reasonable limit (1000), got %d", g.MaxFailures)
	case findTime < time.Second:
		return fmt.Errorf("find_time must be at least 1s, got %v", findTime)
	case findTime > 24*time.Hour:
		return fmt.Errorf("find_time exceeds reasonable limit (24h), got %v", findTime)
	case banTime < time.Second:
		return fmt.Errorf("ban_time must be at least 1s, got %v", banTime)
	case len(g.AllowedCountries) > 0 && len(g.BlockedCountries) > 0:
		return fmt.Errorf("cannot specify both allowed_countries and blocked_countries")
	case dnsRefresh > 0 && dnsRefresh < 30*time.Second:
		return fmt.Errorf("dns_refresh must be at least 30s for stability, got %v", dnsRefresh)
	}
	return nil
}
// BanIP manually adds an IP to the ban list with a reason.
// This is the public API for external callers (like AdminHandler).
//
// Whitelisted IPs are refused, and the attempt is logged. Otherwise the IP's
// failure count is forced to MaxFailures before banning, so the resulting
// BanEntry.HitCount equals MaxFailures.
func (g *SIPGuardian) BanIP(ip, reason string) {
	if g.IsWhitelisted(ip) {
		g.logger.Info("Attempted to ban whitelisted IP", zap.String("ip", ip))
		return
	}

	g.mu.Lock()
	defer g.mu.Unlock()

	if tracker, ok := g.failureCounts[ip]; ok {
		tracker.count = g.MaxFailures
	} else {
		g.failureCounts[ip] = &failureTracker{
			count:     g.MaxFailures,
			firstSeen: time.Now(),
			lastSeen:  time.Now(),
		}
	}

	g.banIP(ip, reason)
}
// Interface guards
var (
_ caddy.Module = (*SIPGuardian)(nil)
_ caddy.Provisioner = (*SIPGuardian)(nil)
_ caddy.CleanerUpper = (*SIPGuardian)(nil)
_ caddy.Validator = (*SIPGuardian)(nil)
_ caddyfile.Unmarshaler = (*SIPGuardian)(nil)
)