package sipguardian

import (
	"database/sql"
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"go.uber.org/zap"
	_ "modernc.org/sqlite" // Pure Go SQLite driver
)

// StorageConfig holds persistent storage configuration.
type StorageConfig struct {
	// Path to the SQLite database file.
	Path string `json:"path,omitempty"`

	// SyncInterval for periodic state sync (default: 30s).
	SyncInterval time.Duration `json:"sync_interval,omitempty"`

	// RetainExpired keeps expired bans for analysis (default: 7 days).
	// Cleanup also uses it as the retention window for recorded failures,
	// suspicious patterns and enumeration attempts.
	RetainExpired time.Duration `json:"retain_expired,omitempty"`
}

// Storage provides persistent storage for SIP Guardian state, backed by a
// SQLite database. Query and record methods serialize database access
// through mu.
type Storage struct {
	db     *sql.DB
	logger *zap.Logger
	config StorageConfig
	mu     sync.Mutex

	// done is closed by Close.
	done chan struct{}
	// closeOnce makes Close safe to call more than once.
	closeOnce sync.Once
}

// Global storage instance
var (
	globalStorage *Storage
	storageMu     sync.Mutex
)

// GetStorage returns the global storage instance, or nil if InitStorage has
// not been called successfully (or the instance has since been closed).
func GetStorage() *Storage {
	storageMu.Lock()
	defer storageMu.Unlock()
	return globalStorage
}

// InitStorage initializes the process-wide persistent storage and returns it.
//
// If storage is already initialized, the existing instance is returned and
// config is ignored. Unset (zero or negative) config fields get defaults:
// Path falls back to $XDG_DATA_HOME/sip-guardian/guardian.db (or
// ~/.local/share/sip-guardian/guardian.db when XDG_DATA_HOME is unset),
// SyncInterval to 30s and RetainExpired to 7 days. A nil logger is replaced
// with a no-op logger.
func InitStorage(logger *zap.Logger, config StorageConfig) (*Storage, error) {
	storageMu.Lock()
	defer storageMu.Unlock()

	if globalStorage != nil {
		return globalStorage, nil
	}

	if logger == nil {
		logger = zap.NewNop()
	}

	// Set defaults
	if config.Path == "" {
		// Default to data directory
		dataDir := os.Getenv("XDG_DATA_HOME")
		if dataDir == "" {
			// Best effort: if the home directory cannot be determined,
			// homeDir is "" and the path becomes relative to the working
			// directory.
			homeDir, _ := os.UserHomeDir()
			dataDir = filepath.Join(homeDir, ".local", "share")
		}
		config.Path = filepath.Join(dataDir, "sip-guardian", "guardian.db")
	}
	// Non-positive durations are treated as unset: a negative RetainExpired
	// would move Cleanup's cutoff into the future and delete all history.
	if config.SyncInterval <= 0 {
		config.SyncInterval = 30 * time.Second
	}
	if config.RetainExpired <= 0 {
		config.RetainExpired = 7 * 24 * time.Hour // 7 days
	}

	// Ensure directory exists
	dir := filepath.Dir(config.Path)
	if err := os.MkdirAll(dir, 0755); err != nil {
		return nil, fmt.Errorf("failed to create storage directory: %w", err)
	}

	// Open database
	db, err := sql.Open("sqlite", config.Path)
	if err != nil {
		return nil, fmt.Errorf("failed to open database: %w", err)
	}

	// Enable WAL mode for better concurrent access
	if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil {
		db.Close()
		return nil, fmt.Errorf("failed to set WAL mode: %w", err)
	}

	storage := &Storage{
		db:     db,
		logger: logger,
		config: config,
		done:   make(chan struct{}),
	}

	// Initialize schema
	if err := storage.initSchema(); err != nil {
		db.Close()
		return nil, fmt.Errorf("failed to initialize schema: %w", err)
	}

	globalStorage = storage

	logger.Info("Storage initialized",
		zap.String("path", config.Path),
		zap.Duration("sync_interval", config.SyncInterval),
	)

	return storage, nil
}

// initSchema creates the database tables and indexes if they do not exist.
func (s *Storage) initSchema() error {
	schema := `
		CREATE TABLE IF NOT EXISTS bans (
			ip TEXT PRIMARY KEY,
			reason TEXT NOT NULL,
			banned_at DATETIME NOT NULL,
			expires_at DATETIME NOT NULL,
			hit_count INTEGER DEFAULT 0,
			unbanned_at DATETIME,
			unban_reason TEXT,
			metadata TEXT
		);

		CREATE INDEX IF NOT EXISTS idx_bans_expires ON bans(expires_at);
		CREATE INDEX IF NOT EXISTS idx_bans_banned_at ON bans(banned_at);

		CREATE TABLE IF NOT EXISTS failures (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			ip TEXT NOT NULL,
			reason TEXT NOT NULL,
			recorded_at DATETIME NOT NULL,
			metadata TEXT
		);

		CREATE INDEX IF NOT EXISTS idx_failures_ip ON failures(ip);
		CREATE INDEX IF NOT EXISTS idx_failures_recorded_at ON failures(recorded_at);

		CREATE TABLE IF NOT EXISTS stats (
			key TEXT PRIMARY KEY,
			value TEXT NOT NULL,
			updated_at DATETIME NOT NULL
		);

		CREATE TABLE IF NOT EXISTS suspicious_patterns (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			ip TEXT NOT NULL,
			pattern TEXT NOT NULL,
			sample TEXT,
			detected_at DATETIME NOT NULL
		);

		CREATE INDEX IF NOT EXISTS idx_suspicious_ip ON suspicious_patterns(ip);
		CREATE INDEX IF NOT EXISTS idx_suspicious_pattern ON suspicious_patterns(pattern);

		CREATE TABLE IF NOT EXISTS enumeration_attempts (
			id INTEGER PRIMARY KEY AUTOINCREMENT,
			ip TEXT NOT NULL,
			reason TEXT NOT NULL,
			unique_count INTEGER NOT NULL,
			extensions TEXT,
			detected_at DATETIME NOT NULL
		);

		CREATE INDEX IF NOT EXISTS idx_enum_ip ON enumeration_attempts(ip);
		CREATE INDEX IF NOT EXISTS idx_enum_detected ON enumeration_attempts(detected_at);
	`

	_, err := s.db.Exec(schema)
	return err
}

// SaveBan persists a ban entry, replacing any existing row for the same IP.
// Replacing the row also clears any previous unban information for that IP.
func (s *Storage) SaveBan(entry *BanEntry) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Timestamps are normalized to UTC: every query in this file compares
	// stored times against time.Now().UTC(), and the driver stores times as
	// text, so mixing zones would make those comparisons wrong.
	_, err := s.db.Exec(`
		INSERT OR REPLACE INTO bans (ip, reason, banned_at, expires_at, hit_count)
		VALUES (?, ?, ?, ?, ?)
	`, entry.IP, entry.Reason, entry.BannedAt.UTC(), entry.ExpiresAt.UTC(), entry.HitCount)

	if err != nil {
		s.logger.Error("Failed to save ban", zap.Error(err), zap.String("ip", entry.IP))
		return err
	}

	s.logger.Debug("Ban saved to storage", zap.String("ip", entry.IP))
	return nil
}

// RemoveBan marks the active ban for ip as lifted, recording when and why.
// The row is kept for history; Cleanup purges it after RetainExpired.
func (s *Storage) RemoveBan(ip, reason string) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	_, err := s.db.Exec(`
		UPDATE bans
		SET unbanned_at = ?, unban_reason = ?
		WHERE ip = ? AND unbanned_at IS NULL
	`, time.Now().UTC(), reason, ip)

	if err != nil {
		s.logger.Error("Failed to remove ban", zap.Error(err), zap.String("ip", ip))
		return err
	}

	s.logger.Debug("Ban removed from storage", zap.String("ip", ip))
	return nil
}

// LoadActiveBans returns all bans that have not expired and have not been
// lifted with RemoveBan. Rows that fail to scan are logged and skipped.
func (s *Storage) LoadActiveBans() ([]BanEntry, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	rows, err := s.db.Query(`
		SELECT ip, reason, banned_at, expires_at, hit_count
		FROM bans
		WHERE expires_at > ? AND unbanned_at IS NULL
	`, time.Now().UTC())
	if err != nil {
		return nil, fmt.Errorf("failed to query active bans: %w", err)
	}
	defer rows.Close()

	var bans []BanEntry
	for rows.Next() {
		var entry BanEntry
		if err := rows.Scan(&entry.IP, &entry.Reason, &entry.BannedAt, &entry.ExpiresAt, &entry.HitCount); err != nil {
			s.logger.Error("Failed to scan ban row", zap.Error(err))
			continue
		}
		bans = append(bans, entry)
	}

	return bans, rows.Err()
}

// RecordFailure records a failure event for historical analysis. metadata,
// when non-nil, is stored as JSON.
func (s *Storage) RecordFailure(ip, reason string, metadata map[string]interface{}) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	var metadataJSON []byte
	if metadata != nil {
		var err error
		metadataJSON, err = json.Marshal(metadata)
		if err != nil {
			return fmt.Errorf("failed to marshal metadata: %w", err)
		}
	}

	_, err := s.db.Exec(`
		INSERT INTO failures (ip, reason, recorded_at, metadata)
		VALUES (?, ?, ?, ?)
	`, ip, reason, time.Now().UTC(), metadataJSON)

	return err
}

// RecordSuspiciousPattern records a suspicious pattern detection.
func (s *Storage) RecordSuspiciousPattern(ip, pattern, sample string) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	_, err := s.db.Exec(`
		INSERT INTO suspicious_patterns (ip, pattern, sample, detected_at)
		VALUES (?, ?, ?, ?)
	`, ip, pattern, sample, time.Now().UTC())

	return err
}

// RecordEnumerationAttempt records an enumeration attack detection.
// extensions, when non-nil, is stored as a JSON array.
func (s *Storage) RecordEnumerationAttempt(ip, reason string, uniqueCount int, extensions []string) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	var extensionsJSON []byte
	if extensions != nil {
		var err error
		extensionsJSON, err = json.Marshal(extensions)
		if err != nil {
			return fmt.Errorf("failed to marshal extensions: %w", err)
		}
	}

	_, err := s.db.Exec(`
		INSERT INTO enumeration_attempts (ip, reason, unique_count, extensions, detected_at)
		VALUES (?, ?, ?, ?, ?)
	`, ip, reason, uniqueCount, extensionsJSON, time.Now().UTC())

	if err != nil {
		s.logger.Error("Failed to record enumeration attempt",
			zap.Error(err),
			zap.String("ip", ip),
			zap.String("reason", reason),
		)
		return err
	}

	s.logger.Debug("Enumeration attempt recorded",
		zap.String("ip", ip),
		zap.String("reason", reason),
		zap.Int("unique_extensions", uniqueCount),
	)

	return nil
}

// GetEnumerationStats returns per-reason statistics on enumeration attempts
// detected within the last `since`, ordered by attempt count descending.
// Each map has keys "reason", "count", "unique_ips" and "avg_extensions".
func (s *Storage) GetEnumerationStats(since time.Duration) ([]map[string]interface{}, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	rows, err := s.db.Query(`
		SELECT reason, COUNT(*) as count, COUNT(DISTINCT ip) as unique_ips,
		       AVG(unique_count) as avg_extensions
		FROM enumeration_attempts
		WHERE detected_at > ?
		GROUP BY reason
		ORDER BY count DESC
	`, time.Now().Add(-since).UTC())
	if err != nil {
		return nil, fmt.Errorf("failed to query enumeration stats: %w", err)
	}
	defer rows.Close()

	var stats []map[string]interface{}
	for rows.Next() {
		var reason string
		var count, uniqueIPs int
		var avgExtensions float64
		if err := rows.Scan(&reason, &count, &uniqueIPs, &avgExtensions); err != nil {
			s.logger.Error("Failed to scan enumeration stats row", zap.Error(err))
			continue
		}
		stats = append(stats, map[string]interface{}{
			"reason":         reason,
			"count":          count,
			"unique_ips":     uniqueIPs,
			"avg_extensions": avgExtensions,
		})
	}

	return stats, rows.Err()
}

// GetRecentEnumerationAttempts returns enumeration attempts detected within
// the last `since`, newest first, at most limit rows (100 when limit is 0).
// The "extensions" key is present only when the stored JSON decodes.
func (s *Storage) GetRecentEnumerationAttempts(since time.Duration, limit int) ([]map[string]interface{}, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if limit == 0 {
		limit = 100
	}

	rows, err := s.db.Query(`
		SELECT ip, reason, unique_count, extensions, detected_at
		FROM enumeration_attempts
		WHERE detected_at > ?
		ORDER BY detected_at DESC
		LIMIT ?
	`, time.Now().Add(-since).UTC(), limit)
	if err != nil {
		return nil, fmt.Errorf("failed to query recent enumeration attempts: %w", err)
	}
	defer rows.Close()

	var attempts []map[string]interface{}
	for rows.Next() {
		var ip, reason string
		var uniqueCount int
		var extensionsJSON sql.NullString
		var detectedAt time.Time

		if err := rows.Scan(&ip, &reason, &uniqueCount, &extensionsJSON, &detectedAt); err != nil {
			s.logger.Error("Failed to scan enumeration attempt row", zap.Error(err))
			continue
		}

		entry := map[string]interface{}{
			"ip":           ip,
			"reason":       reason,
			"unique_count": uniqueCount,
			"detected_at":  detectedAt,
		}

		// Best effort: undecodable (or empty) JSON just omits the key.
		if extensionsJSON.Valid {
			var extensions []string
			if err := json.Unmarshal([]byte(extensionsJSON.String), &extensions); err == nil {
				entry["extensions"] = extensions
			}
		}

		attempts = append(attempts, entry)
	}

	return attempts, rows.Err()
}

// GetBanHistory returns stored ban rows for an IP, newest first.
func (s *Storage) GetBanHistory(ip string) ([]BanEntry, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	rows, err := s.db.Query(`
		SELECT ip, reason, banned_at, expires_at, hit_count
		FROM bans
		WHERE ip = ?
		ORDER BY banned_at DESC
	`, ip)
	if err != nil {
		return nil, fmt.Errorf("failed to query ban history: %w", err)
	}
	defer rows.Close()

	var bans []BanEntry
	for rows.Next() {
		var entry BanEntry
		if err := rows.Scan(&entry.IP, &entry.Reason, &entry.BannedAt, &entry.ExpiresAt, &entry.HitCount); err != nil {
			s.logger.Error("Failed to scan ban history row", zap.Error(err))
			continue
		}
		bans = append(bans, entry)
	}

	return bans, rows.Err()
}

// GetRecentFailures returns failures recorded within the last `since`,
// newest first, at most limit rows (100 when limit is 0). The "metadata" key
// is present only when the stored JSON decodes.
func (s *Storage) GetRecentFailures(since time.Duration, limit int) ([]map[string]interface{}, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if limit == 0 {
		limit = 100
	}

	rows, err := s.db.Query(`
		SELECT ip, reason, recorded_at, metadata
		FROM failures
		WHERE recorded_at > ?
		ORDER BY recorded_at DESC
		LIMIT ?
	`, time.Now().Add(-since).UTC(), limit)
	if err != nil {
		return nil, fmt.Errorf("failed to query recent failures: %w", err)
	}
	defer rows.Close()

	var failures []map[string]interface{}
	for rows.Next() {
		var ip, reason string
		var recordedAt time.Time
		var metadataJSON sql.NullString

		if err := rows.Scan(&ip, &reason, &recordedAt, &metadataJSON); err != nil {
			s.logger.Error("Failed to scan failure row", zap.Error(err))
			continue
		}

		entry := map[string]interface{}{
			"ip":          ip,
			"reason":      reason,
			"recorded_at": recordedAt,
		}

		// Best effort: undecodable (or empty) JSON just omits the key.
		if metadataJSON.Valid {
			var metadata map[string]interface{}
			if err := json.Unmarshal([]byte(metadataJSON.String), &metadata); err == nil {
				entry["metadata"] = metadata
			}
		}

		failures = append(failures, entry)
	}

	return failures, rows.Err()
}

// storageTimeLayouts are the textual timestamp encodings accepted by
// parseStorageTime. The first matches time.Time.String(), which is how
// modernc.org/sqlite writes time.Time values by default; the rest are common
// SQLite/ISO-8601 forms.
var storageTimeLayouts = []string{
	"2006-01-02 15:04:05.999999999 -0700 MST",
	"2006-01-02 15:04:05.999999999-07:00",
	"2006-01-02T15:04:05.999999999-07:00",
	"2006-01-02 15:04:05.999999999",
	"2006-01-02T15:04:05.999999999",
	"2006-01-02",
}

// parseStorageTime converts a raw column value into a time.Time. It is needed
// for expression columns such as MAX(recorded_at): SQLite reports no declared
// type for them, so the driver returns the stored text rather than a
// time.Time, and scanning directly into *time.Time would fail.
func parseStorageTime(v interface{}) (time.Time, error) {
	var raw string
	switch t := v.(type) {
	case time.Time:
		return t, nil
	case string:
		raw = t
	case []byte:
		raw = string(t)
	default:
		return time.Time{}, fmt.Errorf("unsupported time value of type %T", v)
	}

	// time.Time.String() appends a monotonic clock reading (" m=+1.23") for
	// values derived from time.Now(); no parse layout accepts it.
	if i := strings.Index(raw, " m="); i >= 0 {
		raw = raw[:i]
	}
	raw = strings.TrimSpace(raw)

	for _, layout := range storageTimeLayouts {
		if t, err := time.Parse(layout, raw); err == nil {
			return t, nil
		}
	}
	return time.Time{}, fmt.Errorf("unrecognized time format %q", raw)
}

// GetTopOffenders returns the IPs with the most failures recorded within the
// last `since`, at most limit rows (10 when limit is 0). Each map has keys
// "ip", "count" and "last_seen".
func (s *Storage) GetTopOffenders(since time.Duration, limit int) ([]map[string]interface{}, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if limit == 0 {
		limit = 10
	}

	rows, err := s.db.Query(`
		SELECT ip, COUNT(*) as count, MAX(recorded_at) as last_seen
		FROM failures
		WHERE recorded_at > ?
		GROUP BY ip
		ORDER BY count DESC
		LIMIT ?
	`, time.Now().Add(-since).UTC(), limit)
	if err != nil {
		return nil, fmt.Errorf("failed to query top offenders: %w", err)
	}
	defer rows.Close()

	var offenders []map[string]interface{}
	for rows.Next() {
		var ip string
		var count int
		// MAX(...) loses the DATETIME declared type, so scan the raw value
		// and convert it explicitly (see parseStorageTime).
		var lastSeenRaw interface{}
		if err := rows.Scan(&ip, &count, &lastSeenRaw); err != nil {
			s.logger.Error("Failed to scan top offender row", zap.Error(err))
			continue
		}

		lastSeen, err := parseStorageTime(lastSeenRaw)
		if err != nil {
			// Keep the offender; the count is still meaningful.
			s.logger.Warn("Failed to parse last_seen for offender",
				zap.Error(err),
				zap.String("ip", ip),
			)
		}

		offenders = append(offenders, map[string]interface{}{
			"ip":        ip,
			"count":     count,
			"last_seen": lastSeen,
		})
	}

	return offenders, rows.Err()
}

// GetPatternStats returns per-pattern statistics on suspicious patterns
// detected within the last `since`, ordered by detection count descending.
// Each map has keys "pattern", "count" and "unique_ips".
func (s *Storage) GetPatternStats(since time.Duration) ([]map[string]interface{}, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	rows, err := s.db.Query(`
		SELECT pattern, COUNT(*) as count, COUNT(DISTINCT ip) as unique_ips
		FROM suspicious_patterns
		WHERE detected_at > ?
		GROUP BY pattern
		ORDER BY count DESC
	`, time.Now().Add(-since).UTC())
	if err != nil {
		return nil, fmt.Errorf("failed to query pattern stats: %w", err)
	}
	defer rows.Close()

	var stats []map[string]interface{}
	for rows.Next() {
		var pattern string
		var count, uniqueIPs int
		if err := rows.Scan(&pattern, &count, &uniqueIPs); err != nil {
			s.logger.Error("Failed to scan pattern stats row", zap.Error(err))
			continue
		}
		stats = append(stats, map[string]interface{}{
			"pattern":    pattern,
			"count":      count,
			"unique_ips": uniqueIPs,
		})
	}

	return stats, rows.Err()
}

// Cleanup removes data older than the RetainExpired window in a single
// transaction, so either every table is purged or none is.
//
// Bans are removed RetainExpired after they were lifted (RemoveBan) or, for
// bans never lifted, RetainExpired after they expired. Failures, suspicious
// patterns and enumeration attempts are removed RetainExpired after they
// were recorded.
func (s *Storage) Cleanup() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	cutoff := time.Now().Add(-s.config.RetainExpired).UTC()

	tx, err := s.db.Begin()
	if err != nil {
		return fmt.Errorf("failed to begin cleanup transaction: %w", err)
	}
	// After a successful Commit, Rollback returns sql.ErrTxDone; ignore it.
	defer func() { _ = tx.Rollback() }()

	purges := []struct {
		label string
		query string
		args  []interface{}
	}{
		{
			label: "old bans",
			query: `
				DELETE FROM bans
				WHERE (unbanned_at IS NOT NULL AND unbanned_at < ?)
				   OR (unbanned_at IS NULL AND expires_at < ?)
			`,
			args: []interface{}{cutoff, cutoff},
		},
		{
			label: "old failures",
			query: `DELETE FROM failures WHERE recorded_at < ?`,
			args:  []interface{}{cutoff},
		},
		{
			label: "old patterns",
			query: `DELETE FROM suspicious_patterns WHERE detected_at < ?`,
			args:  []interface{}{cutoff},
		},
		{
			label: "old enumeration attempts",
			query: `DELETE FROM enumeration_attempts WHERE detected_at < ?`,
			args:  []interface{}{cutoff},
		},
	}

	for _, p := range purges {
		if _, err := tx.Exec(p.query, p.args...); err != nil {
			return fmt.Errorf("failed to cleanup %s: %w", p.label, err)
		}
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("failed to commit cleanup: %w", err)
	}

	s.logger.Debug("Storage cleanup completed", zap.Time("cutoff", cutoff))
	return nil
}

// Close closes the done channel and the database connection. If s is the
// global instance, the global is cleared so that GetStorage no longer returns
// a closed store and InitStorage can create a fresh one. Calling Close more
// than once is safe; later calls return nil.
func (s *Storage) Close() error {
	var err error
	s.closeOnce.Do(func() {
		close(s.done)
		err = s.db.Close()

		storageMu.Lock()
		if globalStorage == s {
			globalStorage = nil
		}
		storageMu.Unlock()
	})
	return err
}