caddy-sip-guardian/sipguardian.go
Ryan Malloy c73fa9d3d1 Add extension enumeration detection and comprehensive SIP protection
Major features:
- Extension enumeration detection with 3 detection algorithms:
  - Max unique extensions threshold (default: 20 in 5 min)
  - Sequential pattern detection (e.g., 100,101,102...)
  - Rapid-fire detection (many extensions in short window)
- Prometheus metrics for all SIP Guardian operations
- SQLite persistent storage for bans and attack history
- Webhook notifications for ban/unban/suspicious events
- GeoIP-based country blocking with continent shortcuts
- Per-method rate limiting with token bucket algorithm

Bug fixes:
- Fix whitelist count always reporting zero in stats
- Fix whitelisted connections metric never incrementing
- Fix Caddyfile config not being applied to shared guardian

New files:
- enumeration.go: Extension enumeration detector
- enumeration_test.go: 14 comprehensive unit tests
- metrics.go: Prometheus metrics handler
- storage.go: SQLite persistence layer
- webhooks.go: Webhook notification system
- geoip.go: MaxMind GeoIP integration
- ratelimit.go: Per-method rate limiting

Testing:
- sandbox/ contains complete Docker Compose test environment
- All 14 enumeration tests pass
2025-12-07 15:22:28 -07:00

633 lines
14 KiB
Go

// Package sipguardian provides a Caddy module for SIP-aware rate limiting and IP banning.
// It integrates with caddy-l4 for Layer 4 proxying and caddy-ratelimit for rate limiting.
package sipguardian
import (
	"fmt"
	"net"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/caddyserver/caddy/v2"
	"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
	"go.uber.org/zap"
)
// Package-level switches for the optional subsystems. All three default to
// on; code paths guarded by a flag (metrics, webhooks, persistent storage)
// are skipped when that flag is false.
var enableMetrics, enableWebhooks, enableStorage = true, true, true
// init registers the SIPGuardian module with Caddy's module registry so
// Caddy can instantiate it by its module ID.
func init() {
	caddy.RegisterModule(new(SIPGuardian))
}
// BanEntry records a single IP ban. Entries live in SIPGuardian.bannedIPs,
// are returned by value from GetBannedIPs, and use the JSON tags below as
// their serialized field names.
type BanEntry struct {
// IP is the banned address in textual form (also the bannedIPs map key).
IP string `json:"ip"`
// Reason is the failure reason that triggered the ban.
Reason string `json:"reason"`
// BannedAt is when the ban was created.
BannedAt time.Time `json:"banned_at"`
// ExpiresAt is when the ban stops applying; IsBanned treats the IP as
// banned only while the current time is before this instant.
ExpiresAt time.Time `json:"expires_at"`
// HitCount is the IP's recorded failure count at the moment it was banned.
HitCount int `json:"hit_count"`
}
// SIPGuardian implements SIP protection at Layer 4. It counts failures per
// source IP, bans IPs that reach MaxFailures within FindTime, and can
// optionally apply whitelisting, GeoIP country blocking, persistent ban
// storage, webhook notifications and extension-enumeration detection.
type SIPGuardian struct {
// Configuration
// MaxFailures is the failure count within FindTime that triggers a ban.
// Provision defaults it to 5 when unset.
MaxFailures int `json:"max_failures,omitempty"`
// FindTime is the failure-counting window, anchored at an IP's first
// failure. Provision defaults it to 10m when unset.
FindTime caddy.Duration `json:"find_time,omitempty"`
// BanTime is how long a ban lasts. Provision defaults it to 1h when unset.
BanTime caddy.Duration `json:"ban_time,omitempty"`
// WhitelistCIDR lists CIDR ranges whose failures RecordFailure ignores.
// Provision parses these into whitelistNets.
WhitelistCIDR []string `json:"whitelist_cidr,omitempty"`
// Webhook configuration
// Webhooks are added to the manager returned by GetWebhookManager during
// Provision (only when webhooks are enabled).
Webhooks []WebhookConfig `json:"webhooks,omitempty"`
// Storage configuration
// StoragePath is the persistent storage location; empty disables storage.
StoragePath string `json:"storage_path,omitempty"`
// GeoIP configuration
// GeoIPPath is the GeoIP database path; empty disables country checks.
GeoIPPath string `json:"geoip_path,omitempty"`
// BlockedCountries are country codes to block when AllowedCountries is empty.
BlockedCountries []string `json:"blocked_countries,omitempty"`
// AllowedCountries, when non-empty, acts as an allowlist: every other
// country is blocked and BlockedCountries is not consulted.
AllowedCountries []string `json:"allowed_countries,omitempty"`
// Enumeration detection configuration
// Enumeration, when non-nil, is passed to SetEnumerationConfig in Provision.
Enumeration *EnumerationConfig `json:"enumeration,omitempty"`
// Runtime state
logger *zap.Logger
// bannedIPs maps an IP string to its ban entry; guarded by mu.
bannedIPs map[string]*BanEntry
// failureCounts maps an IP string to its current failure window; guarded by mu.
failureCounts map[string]*failureTracker
// whitelistNets is the parsed form of WhitelistCIDR, set during Provision.
whitelistNets []*net.IPNet
// mu guards bannedIPs and failureCounts.
mu sync.RWMutex
// storage is nil when persistence is disabled or failed to initialize.
storage *Storage
// geoIP is nil when GeoIP is disabled or failed to initialize.
geoIP *GeoIPLookup
}
// failureTracker counts one IP's failures within its current FindTime
// window. firstSeen anchors the window; lastSeen is the latest failure.
type failureTracker struct {
	count               int
	firstSeen, lastSeen time.Time
}
// CaddyModule returns the Caddy module information: the "sip_guardian" ID
// and a constructor producing a fresh, unprovisioned *SIPGuardian.
func (SIPGuardian) CaddyModule() caddy.ModuleInfo {
	info := caddy.ModuleInfo{ID: "sip_guardian"}
	info.New = func() caddy.Module { return new(SIPGuardian) }
	return info
}
// Provision sets up the module. It initializes runtime state, applies
// defaults, parses the whitelist, and enables the optional subsystems
// (metrics, webhooks, persistent storage, GeoIP, enumeration detection)
// before starting the background cleanup loop, which exits when ctx is done.
//
// An invalid whitelist entry aborts provisioning. Storage or GeoIP
// initialization failures are logged as warnings and leave that feature
// disabled.
func (g *SIPGuardian) Provision(ctx caddy.Context) error {
	g.logger = ctx.Logger()
	g.bannedIPs = make(map[string]*BanEntry)
	g.failureCounts = make(map[string]*failureTracker)

	// Non-positive values are treated as unset: a negative MaxFailures would
	// ban on the very first failure and a negative BanTime would create bans
	// that have already expired.
	if g.MaxFailures <= 0 {
		g.MaxFailures = 5
	}
	if g.FindTime <= 0 {
		g.FindTime = caddy.Duration(10 * time.Minute)
	}
	if g.BanTime <= 0 {
		g.BanTime = caddy.Duration(1 * time.Hour)
	}

	nets, err := parseWhitelistEntries(g.WhitelistCIDR)
	if err != nil {
		return err
	}
	g.whitelistNets = nets

	if enableMetrics {
		RegisterMetrics()
	}

	if enableWebhooks && len(g.Webhooks) > 0 {
		wm := GetWebhookManager(g.logger)
		for _, wh := range g.Webhooks {
			wm.AddWebhook(wh)
		}
	}

	// Persistence is best-effort: the guardian keeps working in memory if
	// storage cannot be opened.
	if enableStorage && g.StoragePath != "" {
		storage, err := InitStorage(g.logger, StorageConfig{
			Path: g.StoragePath,
		})
		if err != nil {
			g.logger.Warn("Failed to initialize storage, continuing without persistence",
				zap.Error(err),
			)
		} else {
			g.storage = storage
			if err := g.loadBansFromStorage(); err != nil {
				g.logger.Warn("Failed to load bans from storage", zap.Error(err))
			}
		}
	}

	if g.GeoIPPath != "" {
		geoIP, err := NewGeoIPLookup(g.GeoIPPath)
		if err != nil {
			g.logger.Warn("Failed to initialize GeoIP, country blocking disabled",
				zap.Error(err),
			)
		} else {
			g.geoIP = geoIP
			g.logger.Info("GeoIP initialized",
				zap.Int("blocked_countries", len(g.BlockedCountries)),
				zap.Int("allowed_countries", len(g.AllowedCountries)),
			)
		}
	}

	if g.Enumeration != nil {
		SetEnumerationConfig(*g.Enumeration)
		g.logger.Info("Enumeration detection configured",
			zap.Int("max_extensions", g.Enumeration.MaxExtensions),
			zap.Int("sequential_threshold", g.Enumeration.SequentialThreshold),
			zap.Duration("extension_window", g.Enumeration.ExtensionWindow),
		)
	}

	go g.cleanupLoop(ctx)

	g.logger.Info("SIP Guardian initialized",
		zap.Int("max_failures", g.MaxFailures),
		zap.Duration("find_time", time.Duration(g.FindTime)),
		zap.Duration("ban_time", time.Duration(g.BanTime)),
		zap.Int("whitelist_count", len(g.whitelistNets)),
		zap.Bool("storage_enabled", g.storage != nil),
		zap.Bool("geoip_enabled", g.geoIP != nil),
		zap.Int("webhook_count", len(g.Webhooks)),
		zap.Bool("enumeration_enabled", g.Enumeration != nil),
	)
	return nil
}

// parseWhitelistEntries converts whitelist entries into networks. An entry
// may be a CIDR ("10.0.0.0/8") or a bare IP address ("192.0.2.7"); a bare
// address becomes a single-host network (/32 for IPv4, /128 for IPv6).
// It returns nil for an empty input.
func parseWhitelistEntries(entries []string) ([]*net.IPNet, error) {
	var nets []*net.IPNet
	for _, entry := range entries {
		_, network, err := net.ParseCIDR(entry)
		if err != nil {
			ip := net.ParseIP(entry)
			if ip == nil {
				return nil, fmt.Errorf("invalid whitelist CIDR %s: %w", entry, err)
			}
			// Use the 4-byte form for IPv4 so the mask length matches.
			bits := 8 * net.IPv6len
			if v4 := ip.To4(); v4 != nil {
				ip, bits = v4, 8*net.IPv4len
			}
			network = &net.IPNet{IP: ip, Mask: net.CIDRMask(bits, bits)}
		}
		nets = append(nets, network)
	}
	return nets, nil
}
// loadBansFromStorage copies the bans returned by storage.LoadActiveBans
// into the in-memory ban table, replacing any entry for the same IP.
// It is a no-op when storage is disabled and returns the storage error, if any.
func (g *SIPGuardian) loadBansFromStorage() error {
	if g.storage == nil {
		return nil
	}

	bans, err := g.storage.LoadActiveBans()
	if err != nil {
		return err
	}

	g.mu.Lock()
	defer g.mu.Unlock()
	for i := range bans {
		// Take a per-entry copy so each map value points at its own BanEntry.
		ban := bans[i]
		g.bannedIPs[ban.IP] = &ban
	}

	g.logger.Info("Loaded bans from storage", zap.Int("count", len(bans)))
	return nil
}
// IsWhitelisted reports whether ip (a textual IP address) lies inside any
// configured whitelist network. Unparseable input is never whitelisted.
// On a match it also increments the whitelisted-connections metric when
// metrics are enabled.
func (g *SIPGuardian) IsWhitelisted(ip string) bool {
	addr := net.ParseIP(ip)
	if addr == nil {
		return false
	}

	for i := range g.whitelistNets {
		if !g.whitelistNets[i].Contains(addr) {
			continue
		}
		if enableMetrics {
			RecordWhitelistedConnection()
		}
		return true
	}
	return false
}
// IsCountryBlocked reports whether ip should be rejected based on its GeoIP
// country, together with the country code that was looked up.
//
// When AllowedCountries is non-empty it acts as an allowlist and takes
// precedence: any country not listed is blocked. Otherwise the IP is blocked
// only if its country appears in BlockedCountries. Codes are compared
// case-insensitively so a configuration value such as "cn" still matches "CN".
//
// If GeoIP is disabled or the lookup fails, the IP is not blocked (fail open)
// and the returned country is empty.
func (g *SIPGuardian) IsCountryBlocked(ip string) (bool, string) {
	if g.geoIP == nil {
		return false, ""
	}

	country, err := g.geoIP.LookupCountry(ip)
	if err != nil {
		g.logger.Debug("GeoIP lookup failed", zap.String("ip", ip), zap.Error(err))
		return false, ""
	}

	if len(g.AllowedCountries) > 0 {
		for _, allowed := range g.AllowedCountries {
			if strings.EqualFold(country, allowed) {
				return false, country
			}
		}
		return true, country // not in the allowlist
	}

	for _, blocked := range g.BlockedCountries {
		if strings.EqualFold(country, blocked) {
			return true, country
		}
	}
	return false, country
}
// IsBanned reports whether ip has a ban entry whose expiry is still in the
// future. Expired entries that cleanup has not yet removed do not count.
func (g *SIPGuardian) IsBanned(ip string) bool {
	g.mu.RLock()
	defer g.mu.RUnlock()

	entry, found := g.bannedIPs[ip]
	return found && time.Now().Before(entry.ExpiresAt)
}
// RecordFailure records a failed authentication attempt from ip and reports
// whether this failure caused the IP to be banned.
//
// Whitelisted IPs are ignored. Failures are counted in a fixed window that
// starts at an IP's first failure and lasts FindTime. Once the count reaches
// MaxFailures within that window, the IP is banned and its counter cleared.
// Metrics, persistent storage (asynchronously) and webhooks are updated when
// those features are enabled.
func (g *SIPGuardian) RecordFailure(ip, reason string) bool {
	if g.IsWhitelisted(ip) {
		return false
	}

	g.mu.Lock()
	defer g.mu.Unlock()

	now := time.Now()
	findWindow := time.Duration(g.FindTime)

	tracker, exists := g.failureCounts[ip]
	if !exists || now.Sub(tracker.firstSeen) > findWindow {
		// No tracker yet, or the previous window has lapsed: start a new one.
		tracker = &failureTracker{
			count:     1,
			firstSeen: now,
			lastSeen:  now,
		}
		g.failureCounts[ip] = tracker
	} else {
		tracker.count++
		tracker.lastSeen = now
	}
	count := tracker.count

	g.logger.Debug("Failure recorded",
		zap.String("ip", ip),
		zap.String("reason", reason),
		zap.Int("count", count),
	)

	if enableMetrics {
		RecordFailure(reason)
	}

	if g.storage != nil {
		go func() {
			g.storage.RecordFailure(ip, reason, nil)
		}()
	}

	if enableWebhooks {
		EmitFailureEvent(g.logger, ip, reason, count)
	}

	banned := count >= g.MaxFailures
	if banned {
		g.banIP(ip, reason)
	}

	// Sample the tracked-IP gauge last: banIP deletes this IP's tracker, so
	// measuring before the ban would over-report by one.
	if enableMetrics {
		UpdateTrackedIPs(len(g.failureCounts))
	}
	return banned
}
// banIP bans ip for BanTime, recording reason and the IP's current failure
// count (0 if it has no tracker). It clears the IP's failure tracker, logs
// the ban, and updates metrics, persistent storage (asynchronously) and
// webhooks when those features are enabled.
//
// The caller must hold g.mu for writing.
func (g *SIPGuardian) banIP(ip, reason string) {
	bannedAt := time.Now()
	banDuration := time.Duration(g.BanTime)

	var hits int
	if t := g.failureCounts[ip]; t != nil {
		hits = t.count
	}

	entry := &BanEntry{
		IP:        ip,
		Reason:    reason,
		BannedAt:  bannedAt,
		ExpiresAt: bannedAt.Add(banDuration),
		HitCount:  hits,
	}
	g.bannedIPs[ip] = entry
	delete(g.failureCounts, ip) // the ban supersedes any in-progress count

	g.logger.Warn("IP banned",
		zap.String("ip", ip),
		zap.String("reason", reason),
		zap.Duration("duration", banDuration),
	)

	if enableMetrics {
		RecordBan()
	}

	if g.storage != nil {
		go func() {
			if err := g.storage.SaveBan(entry); err != nil {
				g.logger.Error("Failed to save ban to storage", zap.Error(err))
			}
		}()
	}

	if enableWebhooks {
		EmitBanEvent(g.logger, entry)
	}
}
// UnbanIP manually removes ip from the ban table and reports whether an
// entry existed. On removal it records ban-duration and unban metrics,
// asynchronously marks the ban removed in storage, and emits an unban
// webhook, all tagged with the reason "manual_unban".
func (g *SIPGuardian) UnbanIP(ip string) bool {
	g.mu.Lock()
	defer g.mu.Unlock()

	entry, ok := g.bannedIPs[ip]
	if !ok {
		return false
	}

	if enableMetrics {
		RecordBanDuration(time.Since(entry.BannedAt).Seconds())
		RecordUnban()
	}

	delete(g.bannedIPs, ip)
	g.logger.Info("IP unbanned", zap.String("ip", ip))

	if g.storage != nil {
		go func() {
			if err := g.storage.RemoveBan(ip, "manual_unban"); err != nil {
				g.logger.Error("Failed to update storage on unban", zap.Error(err))
			}
		}()
	}

	if enableWebhooks {
		EmitUnbanEvent(g.logger, ip, "manual_unban")
	}
	return true
}
// GetBannedIPs returns copies of all unexpired ban entries, in unspecified
// order. It returns nil when no ban is active.
func (g *SIPGuardian) GetBannedIPs() []BanEntry {
	g.mu.RLock()
	defer g.mu.RUnlock()

	now := time.Now()
	var active []BanEntry
	for _, e := range g.bannedIPs {
		if !now.Before(e.ExpiresAt) {
			continue // expired, awaiting cleanup
		}
		active = append(active, *e)
	}
	return active
}
// GetStats returns a snapshot of guardian counters:
//   - "active_bans": bans whose expiry is still in the future
//   - "tracked_failures": IPs with a failure tracker (including lapsed
//     windows that cleanup has not removed yet)
//   - "whitelist_count": parsed whitelist networks
func (g *SIPGuardian) GetStats() map[string]interface{} {
	g.mu.RLock()
	defer g.mu.RUnlock()

	now := time.Now()
	var active int
	for _, entry := range g.bannedIPs {
		if now.Before(entry.ExpiresAt) {
			active++
		}
	}

	stats := make(map[string]interface{}, 3)
	stats["active_bans"] = active
	stats["tracked_failures"] = len(g.failureCounts)
	stats["whitelist_count"] = len(g.whitelistNets)
	return stats
}
// cleanupLoop calls cleanup once a minute until ctx is done, then stops its
// ticker and returns.
func (g *SIPGuardian) cleanupLoop(ctx caddy.Context) {
	const interval = time.Minute

	tick := time.NewTicker(interval)
	defer tick.Stop()

	done := ctx.Done()
	for {
		select {
		case <-tick.C:
			g.cleanup()
		case <-done:
			return
		}
	}
}
// cleanup removes expired bans and failure trackers whose FindTime window
// has lapsed. It runs under the write lock. When metrics are enabled, it
// refreshes the tracked-IP gauge so the gauge reflects the pruned tracker
// count instead of staying stale until the next recorded failure.
func (g *SIPGuardian) cleanup() {
	g.mu.Lock()
	defer g.mu.Unlock()

	now := time.Now()
	findWindow := time.Duration(g.FindTime)

	// Treat a ban as expired once now is not before ExpiresAt, matching the
	// IsBanned/GetBannedIPs definition of an active ban. Deleting map entries
	// during range is safe in Go.
	for ip, entry := range g.bannedIPs {
		if !now.Before(entry.ExpiresAt) {
			delete(g.bannedIPs, ip)
			g.logger.Debug("Ban expired", zap.String("ip", ip))
		}
	}

	for ip, tracker := range g.failureCounts {
		if now.Sub(tracker.firstSeen) > findWindow {
			delete(g.failureCounts, ip)
		}
	}

	if enableMetrics {
		UpdateTrackedIPs(len(g.failureCounts))
	}
}
// UnmarshalCaddyfile implements caddyfile.Unmarshaler.
//
// Supported syntax:
//
//	sip_guardian {
//		max_failures 5
//		find_time 10m
//		ban_time 1h
//		whitelist 10.0.0.0/8 192.168.0.0/16
//
//		# Persistent storage
//		storage /var/lib/sip-guardian/guardian.db
//
//		# GeoIP blocking (requires MaxMind database)
//		geoip_db /path/to/GeoLite2-Country.mmdb
//		block_countries CN RU KP
//		allow_countries US CA GB  # Alternative: only allow these
//
//		# Webhook notifications
//		webhook https://example.com/hook {
//			events ban unban suspicious
//			secret my-webhook-secret
//			timeout 10s
//			header X-Name value
//		}
//	}
//
// Single-value directives (max_failures, find_time, ban_time, storage,
// geoip_db, webhook, secret, timeout) take exactly one argument. max_failures
// must be a positive integer and all durations must be positive. Country
// arguments are passed through ExpandContinentCode, so continent codes expand
// to their member countries.
func (g *SIPGuardian) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
	// oneArg consumes exactly one argument for the current directive,
	// rejecting both missing and extra arguments.
	oneArg := func() (string, error) {
		if !d.NextArg() {
			return "", d.ArgErr()
		}
		val := d.Val()
		if d.NextArg() {
			return "", d.ArgErr()
		}
		return val, nil
	}

	// positiveDuration parses the directive's single argument as a positive
	// Caddy duration; name is used in error messages.
	positiveDuration := func(name string) (time.Duration, error) {
		raw, err := oneArg()
		if err != nil {
			return 0, err
		}
		dur, err := caddy.ParseDuration(raw)
		if err != nil {
			return 0, d.Errf("invalid %s: %v", name, err)
		}
		if dur <= 0 {
			return 0, d.Errf("invalid %s: must be positive, got %s", name, raw)
		}
		return dur, nil
	}

	// appendCountries appends each remaining argument to dst, expanding
	// continent codes into their member countries.
	appendCountries := func(dst []string) []string {
		for d.NextArg() {
			country := d.Val()
			if expanded := ExpandContinentCode(country); expanded != nil {
				dst = append(dst, expanded...)
			} else {
				dst = append(dst, country)
			}
		}
		return dst
	}

	for d.Next() {
		for d.NextBlock(0) {
			switch d.Val() {
			case "max_failures":
				raw, err := oneArg()
				if err != nil {
					return err
				}
				// strconv.Atoi rejects trailing garbage ("5abc") that
				// fmt.Sscanf("%d") would silently accept.
				n, err := strconv.Atoi(raw)
				if err != nil {
					return d.Errf("invalid max_failures: %v", err)
				}
				if n <= 0 {
					return d.Errf("invalid max_failures: must be positive, got %d", n)
				}
				g.MaxFailures = n

			case "find_time":
				dur, err := positiveDuration("find_time")
				if err != nil {
					return err
				}
				g.FindTime = caddy.Duration(dur)

			case "ban_time":
				dur, err := positiveDuration("ban_time")
				if err != nil {
					return err
				}
				g.BanTime = caddy.Duration(dur)

			case "whitelist":
				g.WhitelistCIDR = append(g.WhitelistCIDR, d.RemainingArgs()...)

			case "storage":
				path, err := oneArg()
				if err != nil {
					return err
				}
				g.StoragePath = path

			case "geoip_db":
				path, err := oneArg()
				if err != nil {
					return err
				}
				g.GeoIPPath = path

			case "block_countries":
				g.BlockedCountries = appendCountries(g.BlockedCountries)

			case "allow_countries":
				g.AllowedCountries = appendCountries(g.AllowedCountries)

			case "webhook":
				url, err := oneArg()
				if err != nil {
					return err
				}
				webhook := WebhookConfig{
					URL: url,
				}

				// Parse the optional webhook sub-block.
				for nesting := d.Nesting(); d.NextBlock(nesting); {
					switch d.Val() {
					case "events":
						webhook.Events = d.RemainingArgs()
					case "secret":
						secret, err := oneArg()
						if err != nil {
							return err
						}
						webhook.Secret = secret
					case "timeout":
						dur, err := positiveDuration("webhook timeout")
						if err != nil {
							return err
						}
						webhook.Timeout = dur
					case "header":
						args := d.RemainingArgs()
						if len(args) != 2 {
							return d.Errf("header requires name and value")
						}
						if webhook.Headers == nil {
							webhook.Headers = make(map[string]string)
						}
						webhook.Headers[args[0]] = args[1]
					default:
						return d.Errf("unknown webhook directive: %s", d.Val())
					}
				}
				g.Webhooks = append(g.Webhooks, webhook)

			default:
				return d.Errf("unknown directive: %s", d.Val())
			}
		}
	}
	return nil
}
// Compile-time checks that *SIPGuardian implements the Caddy interfaces
// it relies on; a missing method becomes a build error here.
var _ caddy.Module = (*SIPGuardian)(nil)
var _ caddy.Provisioner = (*SIPGuardian)(nil)
var _ caddyfile.Unmarshaler = (*SIPGuardian)(nil)