caddy-sip-guardian/storage.go
Ryan Malloy c73fa9d3d1 Add extension enumeration detection and comprehensive SIP protection
Major features:
- Extension enumeration detection with 3 detection algorithms:
  - Max unique extensions threshold (default: 20 in 5 min)
  - Sequential pattern detection (e.g., 100,101,102...)
  - Rapid-fire detection (many extensions in short window)
- Prometheus metrics for all SIP Guardian operations
- SQLite persistent storage for bans and attack history
- Webhook notifications for ban/unban/suspicious events
- GeoIP-based country blocking with continent shortcuts
- Per-method rate limiting with token bucket algorithm

Bug fixes:
- Fix whitelist count always reporting zero in stats
- Fix whitelisted connections metric never incrementing
- Fix Caddyfile config not being applied to shared guardian

New files:
- enumeration.go: Extension enumeration detector
- enumeration_test.go: 14 comprehensive unit tests
- metrics.go: Prometheus metrics handler
- storage.go: SQLite persistence layer
- webhooks.go: Webhook notification system
- geoip.go: MaxMind GeoIP integration
- ratelimit.go: Per-method rate limiting

Testing:
- sandbox/ contains complete Docker Compose test environment
- All 14 enumeration tests pass
2025-12-07 15:22:28 -07:00

609 lines
15 KiB
Go

package sipguardian
import (
"database/sql"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
"time"
"go.uber.org/zap"
_ "modernc.org/sqlite" // Pure Go SQLite driver
)
// StorageConfig holds persistent storage configuration.
type StorageConfig struct {
	// Path is the SQLite database file location. If empty, InitStorage
	// derives a default under $XDG_DATA_HOME (or ~/.local/share).
	Path string `json:"path,omitempty"`
	// SyncInterval is the period for periodic state sync
	// (InitStorage default: 30s).
	SyncInterval time.Duration `json:"sync_interval,omitempty"`
	// RetainExpired keeps expired/unbanned bans and history rows for
	// analysis before Cleanup deletes them (InitStorage default: 7 days).
	RetainExpired time.Duration `json:"retain_expired,omitempty"`
}
// Storage provides persistent storage for SIP Guardian state: bans,
// failures, suspicious patterns, and enumeration attempts, backed by SQLite.
type Storage struct {
	db     *sql.DB       // SQLite handle (WAL mode, see InitStorage)
	logger *zap.Logger   // structured logger for storage events
	config StorageConfig // resolved configuration (defaults applied)
	mu     sync.Mutex    // serializes all database access through this struct
	done   chan struct{} // closed by Close to signal shutdown to background work
}
// Global storage instance. Both variables are guarded by storageMu;
// globalStorage is set exactly once by InitStorage.
var (
	globalStorage *Storage
	storageMu     sync.Mutex
)
// GetStorage returns the global storage instance, or nil if InitStorage
// has not been called yet.
func GetStorage() *Storage {
	storageMu.Lock()
	s := globalStorage
	storageMu.Unlock()
	return s
}
// InitStorage initializes the persistent storage and installs it as the
// global instance. Subsequent calls return the existing instance and
// intentionally ignore the supplied config.
//
// Defaults applied when unset in config:
//   - Path: $XDG_DATA_HOME/sip-guardian/guardian.db (falling back to
//     ~/.local/share, then the OS temp dir if no home is resolvable)
//   - SyncInterval: 30 seconds
//   - RetainExpired: 7 days
//
// The database is opened with WAL journaling for better concurrent access.
func InitStorage(logger *zap.Logger, config StorageConfig) (*Storage, error) {
	storageMu.Lock()
	defer storageMu.Unlock()
	if globalStorage != nil {
		// Already initialized by an earlier caller.
		return globalStorage, nil
	}
	// Set defaults
	if config.Path == "" {
		// Default to data directory
		dataDir := os.Getenv("XDG_DATA_HOME")
		if dataDir == "" {
			homeDir, err := os.UserHomeDir()
			if err != nil || homeDir == "" {
				// Previously the error was ignored, which yielded a
				// CWD-relative ".local/share" path when no home directory
				// exists (e.g. minimal containers). Fall back to a stable
				// absolute location instead.
				homeDir = os.TempDir()
			}
			dataDir = filepath.Join(homeDir, ".local", "share")
		}
		config.Path = filepath.Join(dataDir, "sip-guardian", "guardian.db")
	}
	if config.SyncInterval == 0 {
		config.SyncInterval = 30 * time.Second
	}
	if config.RetainExpired == 0 {
		config.RetainExpired = 7 * 24 * time.Hour // 7 days
	}
	// Ensure directory exists
	dir := filepath.Dir(config.Path)
	if err := os.MkdirAll(dir, 0755); err != nil {
		return nil, fmt.Errorf("failed to create storage directory: %w", err)
	}
	// Open database
	db, err := sql.Open("sqlite", config.Path)
	if err != nil {
		return nil, fmt.Errorf("failed to open database: %w", err)
	}
	// Enable WAL mode for better concurrent access
	if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil {
		db.Close()
		return nil, fmt.Errorf("failed to set WAL mode: %w", err)
	}
	storage := &Storage{
		db:     db,
		logger: logger,
		config: config,
		done:   make(chan struct{}),
	}
	// Initialize schema
	if err := storage.initSchema(); err != nil {
		db.Close()
		return nil, fmt.Errorf("failed to initialize schema: %w", err)
	}
	globalStorage = storage
	logger.Info("Storage initialized",
		zap.String("path", config.Path),
		zap.Duration("sync_interval", config.SyncInterval),
	)
	return storage, nil
}
// initSchema creates the database tables and indexes if they do not
// already exist. Idempotent: every statement uses IF NOT EXISTS, so it is
// safe to run on every startup against an existing database.
//
// Tables:
//   - bans: one row per banned IP (PRIMARY KEY ip); unbanned rows are kept
//     with unbanned_at/unban_reason set for history until Cleanup removes them.
//   - failures: append-only failure events per IP, with optional JSON metadata.
//   - stats: simple key/value store with last-update timestamp.
//   - suspicious_patterns: detected pattern events per IP with an optional sample.
//   - enumeration_attempts: extension-enumeration detections, with the
//     offending extension list stored as JSON in the extensions column.
func (s *Storage) initSchema() error {
	schema := `
CREATE TABLE IF NOT EXISTS bans (
ip TEXT PRIMARY KEY,
reason TEXT NOT NULL,
banned_at DATETIME NOT NULL,
expires_at DATETIME NOT NULL,
hit_count INTEGER DEFAULT 0,
unbanned_at DATETIME,
unban_reason TEXT,
metadata TEXT
);
CREATE INDEX IF NOT EXISTS idx_bans_expires ON bans(expires_at);
CREATE INDEX IF NOT EXISTS idx_bans_banned_at ON bans(banned_at);
CREATE TABLE IF NOT EXISTS failures (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ip TEXT NOT NULL,
reason TEXT NOT NULL,
recorded_at DATETIME NOT NULL,
metadata TEXT
);
CREATE INDEX IF NOT EXISTS idx_failures_ip ON failures(ip);
CREATE INDEX IF NOT EXISTS idx_failures_recorded_at ON failures(recorded_at);
CREATE TABLE IF NOT EXISTS stats (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
updated_at DATETIME NOT NULL
);
CREATE TABLE IF NOT EXISTS suspicious_patterns (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ip TEXT NOT NULL,
pattern TEXT NOT NULL,
sample TEXT,
detected_at DATETIME NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_suspicious_ip ON suspicious_patterns(ip);
CREATE INDEX IF NOT EXISTS idx_suspicious_pattern ON suspicious_patterns(pattern);
CREATE TABLE IF NOT EXISTS enumeration_attempts (
id INTEGER PRIMARY KEY AUTOINCREMENT,
ip TEXT NOT NULL,
reason TEXT NOT NULL,
unique_count INTEGER NOT NULL,
extensions TEXT,
detected_at DATETIME NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_enum_ip ON enumeration_attempts(ip);
CREATE INDEX IF NOT EXISTS idx_enum_detected ON enumeration_attempts(detected_at);
`
	// Multiple statements in one Exec call; the sqlite driver runs them all.
	_, err := s.db.Exec(schema)
	return err
}
// SaveBan persists a ban entry, replacing any existing row for the same IP
// (which also clears a previous unbanned_at marker on re-ban).
func (s *Storage) SaveBan(entry *BanEntry) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	const stmt = `
INSERT OR REPLACE INTO bans (ip, reason, banned_at, expires_at, hit_count)
VALUES (?, ?, ?, ?, ?)
`
	if _, err := s.db.Exec(stmt, entry.IP, entry.Reason, entry.BannedAt, entry.ExpiresAt, entry.HitCount); err != nil {
		s.logger.Error("Failed to save ban", zap.Error(err), zap.String("ip", entry.IP))
		return err
	}
	s.logger.Debug("Ban saved to storage", zap.String("ip", entry.IP))
	return nil
}
// RemoveBan marks a ban as unbanned while keeping the row for history.
// Only rows not already unbanned are updated.
func (s *Storage) RemoveBan(ip, reason string) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	now := time.Now().UTC()
	_, err := s.db.Exec(`
UPDATE bans SET unbanned_at = ?, unban_reason = ? WHERE ip = ? AND unbanned_at IS NULL
`, now, reason, ip)
	if err == nil {
		s.logger.Debug("Ban removed from storage", zap.String("ip", ip))
		return nil
	}
	s.logger.Error("Failed to remove ban", zap.Error(err), zap.String("ip", ip))
	return err
}
// LoadActiveBans returns all bans that have not yet expired and were not
// manually unbanned. Rows that fail to scan are logged and skipped.
func (s *Storage) LoadActiveBans() ([]BanEntry, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	now := time.Now().UTC()
	rows, err := s.db.Query(`
SELECT ip, reason, banned_at, expires_at, hit_count
FROM bans
WHERE expires_at > ? AND unbanned_at IS NULL
`, now)
	if err != nil {
		return nil, fmt.Errorf("failed to query active bans: %w", err)
	}
	defer rows.Close()

	var active []BanEntry
	for rows.Next() {
		var b BanEntry
		scanErr := rows.Scan(&b.IP, &b.Reason, &b.BannedAt, &b.ExpiresAt, &b.HitCount)
		if scanErr != nil {
			s.logger.Error("Failed to scan ban row", zap.Error(scanErr))
			continue
		}
		active = append(active, b)
	}
	return active, rows.Err()
}
// RecordFailure records a failure event for historical analysis. The
// optional metadata map is stored as JSON; a nil map is stored as NULL.
func (s *Storage) RecordFailure(ip, reason string, metadata map[string]interface{}) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	var encoded []byte
	if metadata != nil {
		b, marshalErr := json.Marshal(metadata)
		if marshalErr != nil {
			return fmt.Errorf("failed to marshal metadata: %w", marshalErr)
		}
		encoded = b
	}
	_, err := s.db.Exec(`
INSERT INTO failures (ip, reason, recorded_at, metadata)
VALUES (?, ?, ?, ?)
`, ip, reason, time.Now().UTC(), encoded)
	return err
}
// RecordSuspiciousPattern records a suspicious pattern detection for an IP,
// together with an optional sample of the offending input.
func (s *Storage) RecordSuspiciousPattern(ip, pattern, sample string) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	detectedAt := time.Now().UTC()
	if _, err := s.db.Exec(`
INSERT INTO suspicious_patterns (ip, pattern, sample, detected_at)
VALUES (?, ?, ?, ?)
`, ip, pattern, sample, detectedAt); err != nil {
		return err
	}
	return nil
}
// RecordEnumerationAttempt records an enumeration attack detection. The
// probed extension list is stored as JSON; a nil slice is stored as NULL.
func (s *Storage) RecordEnumerationAttempt(ip, reason string, uniqueCount int, extensions []string) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	var encoded []byte
	if extensions != nil {
		b, marshalErr := json.Marshal(extensions)
		if marshalErr != nil {
			return fmt.Errorf("failed to marshal extensions: %w", marshalErr)
		}
		encoded = b
	}
	if _, err := s.db.Exec(`
INSERT INTO enumeration_attempts (ip, reason, unique_count, extensions, detected_at)
VALUES (?, ?, ?, ?, ?)
`, ip, reason, uniqueCount, encoded, time.Now().UTC()); err != nil {
		s.logger.Error("Failed to record enumeration attempt",
			zap.Error(err),
			zap.String("ip", ip),
			zap.String("reason", reason),
		)
		return err
	}
	s.logger.Debug("Enumeration attempt recorded",
		zap.String("ip", ip),
		zap.String("reason", reason),
		zap.Int("unique_extensions", uniqueCount),
	)
	return nil
}
// GetEnumerationStats returns per-reason statistics on enumeration attempts
// detected within the trailing `since` window: event count, number of
// distinct source IPs, and the average unique-extension count, ordered by
// event count descending.
func (s *Storage) GetEnumerationStats(since time.Duration) ([]map[string]interface{}, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	rows, err := s.db.Query(`
SELECT reason, COUNT(*) as count, COUNT(DISTINCT ip) as unique_ips, AVG(unique_count) as avg_extensions
FROM enumeration_attempts
WHERE detected_at > ?
GROUP BY reason
ORDER BY count DESC
`, time.Now().Add(-since).UTC())
	if err != nil {
		return nil, fmt.Errorf("failed to query enumeration stats: %w", err)
	}
	defer rows.Close()
	var stats []map[string]interface{}
	for rows.Next() {
		var reason string
		var count, uniqueIPs int
		var avgExtensions float64
		if err := rows.Scan(&reason, &count, &uniqueIPs, &avgExtensions); err != nil {
			// Fix: previously scan failures were silently skipped; log them
			// for consistency with LoadActiveBans so data problems surface.
			s.logger.Error("Failed to scan enumeration stats row", zap.Error(err))
			continue
		}
		stats = append(stats, map[string]interface{}{
			"reason":         reason,
			"count":          count,
			"unique_ips":     uniqueIPs,
			"avg_extensions": avgExtensions,
		})
	}
	return stats, rows.Err()
}
// GetRecentEnumerationAttempts returns enumeration attempts detected within
// the trailing `since` window, newest first, capped at `limit` rows
// (default 100 when limit is zero or negative). The stored extensions JSON
// is decoded back into a []string when present and well-formed.
func (s *Storage) GetRecentEnumerationAttempts(since time.Duration, limit int) ([]map[string]interface{}, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	// Fix: guard against negative limits too — in SQLite, LIMIT with a
	// negative value means "no limit" and would return the whole table.
	if limit <= 0 {
		limit = 100
	}
	rows, err := s.db.Query(`
SELECT ip, reason, unique_count, extensions, detected_at
FROM enumeration_attempts
WHERE detected_at > ?
ORDER BY detected_at DESC
LIMIT ?
`, time.Now().Add(-since).UTC(), limit)
	if err != nil {
		return nil, fmt.Errorf("failed to query recent enumeration attempts: %w", err)
	}
	defer rows.Close()
	var attempts []map[string]interface{}
	for rows.Next() {
		var ip, reason string
		var uniqueCount int
		var extensionsJSON sql.NullString // column is nullable
		var detectedAt time.Time
		if err := rows.Scan(&ip, &reason, &uniqueCount, &extensionsJSON, &detectedAt); err != nil {
			s.logger.Error("Failed to scan enumeration attempt row", zap.Error(err))
			continue
		}
		entry := map[string]interface{}{
			"ip":           ip,
			"reason":       reason,
			"unique_count": uniqueCount,
			"detected_at":  detectedAt,
		}
		if extensionsJSON.Valid {
			var extensions []string
			// Malformed JSON is tolerated: the entry is returned without
			// the extensions key rather than failing the whole query.
			if err := json.Unmarshal([]byte(extensionsJSON.String), &extensions); err == nil {
				entry["extensions"] = extensions
			}
		}
		attempts = append(attempts, entry)
	}
	return attempts, rows.Err()
}
// GetBanHistory returns every stored ban row for the given IP, newest first.
// Note: since bans keys on IP, at most one row exists per IP at a time.
func (s *Storage) GetBanHistory(ip string) ([]BanEntry, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	rows, queryErr := s.db.Query(`
SELECT ip, reason, banned_at, expires_at, hit_count
FROM bans
WHERE ip = ?
ORDER BY banned_at DESC
`, ip)
	if queryErr != nil {
		return nil, fmt.Errorf("failed to query ban history: %w", queryErr)
	}
	defer rows.Close()

	var history []BanEntry
	for rows.Next() {
		var b BanEntry
		if scanErr := rows.Scan(&b.IP, &b.Reason, &b.BannedAt, &b.ExpiresAt, &b.HitCount); scanErr != nil {
			continue
		}
		history = append(history, b)
	}
	return history, rows.Err()
}
// GetRecentFailures returns failure events recorded within the trailing
// `since` window, newest first, capped at `limit` rows (default 100 when
// limit is zero or negative). Stored metadata JSON is decoded back into a
// map when present and well-formed.
func (s *Storage) GetRecentFailures(since time.Duration, limit int) ([]map[string]interface{}, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	// Fix: guard against negative limits too — in SQLite, LIMIT with a
	// negative value means "no limit" and would return the whole table.
	if limit <= 0 {
		limit = 100
	}
	rows, err := s.db.Query(`
SELECT ip, reason, recorded_at, metadata
FROM failures
WHERE recorded_at > ?
ORDER BY recorded_at DESC
LIMIT ?
`, time.Now().Add(-since).UTC(), limit)
	if err != nil {
		return nil, fmt.Errorf("failed to query recent failures: %w", err)
	}
	defer rows.Close()
	var failures []map[string]interface{}
	for rows.Next() {
		var ip, reason string
		var recordedAt time.Time
		var metadataJSON sql.NullString // column is nullable
		if err := rows.Scan(&ip, &reason, &recordedAt, &metadataJSON); err != nil {
			s.logger.Error("Failed to scan failure row", zap.Error(err))
			continue
		}
		entry := map[string]interface{}{
			"ip":          ip,
			"reason":      reason,
			"recorded_at": recordedAt,
		}
		if metadataJSON.Valid {
			var metadata map[string]interface{}
			// Malformed JSON is tolerated: the entry is returned without
			// the metadata key rather than failing the whole query.
			if err := json.Unmarshal([]byte(metadataJSON.String), &metadata); err == nil {
				entry["metadata"] = metadata
			}
		}
		failures = append(failures, entry)
	}
	return failures, rows.Err()
}
// GetTopOffenders returns the IPs with the most failure events in the
// trailing `since` window, ordered by failure count descending, capped at
// `limit` rows (default 10 when limit is zero or negative).
func (s *Storage) GetTopOffenders(since time.Duration, limit int) ([]map[string]interface{}, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	// Fix: guard against negative limits too — in SQLite, LIMIT with a
	// negative value means "no limit" and would return every offender.
	if limit <= 0 {
		limit = 10
	}
	rows, err := s.db.Query(`
SELECT ip, COUNT(*) as count, MAX(recorded_at) as last_seen
FROM failures
WHERE recorded_at > ?
GROUP BY ip
ORDER BY count DESC
LIMIT ?
`, time.Now().Add(-since).UTC(), limit)
	if err != nil {
		return nil, fmt.Errorf("failed to query top offenders: %w", err)
	}
	defer rows.Close()
	var offenders []map[string]interface{}
	for rows.Next() {
		var ip string
		var count int
		var lastSeen time.Time
		if err := rows.Scan(&ip, &count, &lastSeen); err != nil {
			s.logger.Error("Failed to scan top offender row", zap.Error(err))
			continue
		}
		offenders = append(offenders, map[string]interface{}{
			"ip":        ip,
			"count":     count,
			"last_seen": lastSeen,
		})
	}
	return offenders, rows.Err()
}
// GetPatternStats returns per-pattern statistics on suspicious pattern
// detections within the trailing `since` window: event count and number of
// distinct source IPs, ordered by event count descending.
func (s *Storage) GetPatternStats(since time.Duration) ([]map[string]interface{}, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	cutoff := time.Now().Add(-since).UTC()
	rows, queryErr := s.db.Query(`
SELECT pattern, COUNT(*) as count, COUNT(DISTINCT ip) as unique_ips
FROM suspicious_patterns
WHERE detected_at > ?
GROUP BY pattern
ORDER BY count DESC
`, cutoff)
	if queryErr != nil {
		return nil, fmt.Errorf("failed to query pattern stats: %w", queryErr)
	}
	defer rows.Close()

	var results []map[string]interface{}
	for rows.Next() {
		var (
			name           string
			total, ipCount int
		)
		if scanErr := rows.Scan(&name, &total, &ipCount); scanErr != nil {
			continue
		}
		results = append(results, map[string]interface{}{
			"pattern":    name,
			"count":      total,
			"unique_ips": ipCount,
		})
	}
	return results, rows.Err()
}
// Cleanup removes rows older than the RetainExpired window: unbanned ban
// entries, failures, suspicious patterns, and enumeration attempts.
// Active bans are never touched (only rows with unbanned_at set).
func (s *Storage) Cleanup() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	cutoff := time.Now().Add(-s.config.RetainExpired).UTC()
	// Each step deletes one table's stale rows; the first failure aborts.
	steps := []struct {
		what string
		stmt string
	}{
		{"old bans", `
DELETE FROM bans WHERE unbanned_at IS NOT NULL AND unbanned_at < ?
`},
		{"old failures", `
DELETE FROM failures WHERE recorded_at < ?
`},
		{"old patterns", `
DELETE FROM suspicious_patterns WHERE detected_at < ?
`},
		{"old enumeration attempts", `
DELETE FROM enumeration_attempts WHERE detected_at < ?
`},
	}
	for _, step := range steps {
		if _, err := s.db.Exec(step.stmt, cutoff); err != nil {
			return fmt.Errorf("failed to cleanup %s: %w", step.what, err)
		}
	}
	s.logger.Debug("Storage cleanup completed", zap.Time("cutoff", cutoff))
	return nil
}
// Close signals shutdown via the done channel and closes the database
// connection. Close is idempotent: previously a second call would panic on
// the double close of s.done; now repeated calls return nil.
func (s *Storage) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock()
	select {
	case <-s.done:
		// Already closed.
		return nil
	default:
		close(s.done)
	}
	return s.db.Close()
}