diff --git a/app.go b/app.go new file mode 100644 index 0000000..c1cf902 --- /dev/null +++ b/app.go @@ -0,0 +1,152 @@ +package sipguardian + +import ( + "fmt" + "sync" + + "github.com/caddyserver/caddy/v2" + "go.uber.org/zap" +) + +func init() { + caddy.RegisterModule(SIPGuardianApp{}) +} + +// SIPGuardianApp is a Caddy app that manages SIPGuardian instances +// This replaces the global registry pattern with proper Caddy lifecycle management +type SIPGuardianApp struct { + guardians map[string]*SIPGuardian + mu sync.RWMutex + logger *zap.Logger +} + +// CaddyModule returns the Caddy module information +func (SIPGuardianApp) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + ID: "sip_guardian", + New: func() caddy.Module { return &SIPGuardianApp{} }, + } +} + +// Provision sets up the app +func (app *SIPGuardianApp) Provision(ctx caddy.Context) error { + app.guardians = make(map[string]*SIPGuardian) + app.logger = ctx.Logger() + app.logger.Debug("SIP Guardian app provisioned") + return nil +} + +// Start starts the app (no-op for us, guardians start when created) +func (app *SIPGuardianApp) Start() error { + app.logger.Info("SIP Guardian app started") + return nil +} + +// Stop stops the app and cleans up all guardians +func (app *SIPGuardianApp) Stop() error { + app.mu.Lock() + defer app.mu.Unlock() + + app.logger.Info("SIP Guardian app stopping", zap.Int("guardians", len(app.guardians))) + + // Cleanup all guardians + for name, guardian := range app.guardians { + app.logger.Debug("Cleaning up guardian", zap.String("name", name)) + if err := guardian.Cleanup(); err != nil { + app.logger.Error("Error cleaning up guardian", + zap.String("name", name), + zap.Error(err), + ) + } + } + + // Clear the map + app.guardians = make(map[string]*SIPGuardian) + app.logger.Debug("SIP Guardian app stopped") + + return nil +} + +// GetOrCreateGuardian returns a shared guardian instance, creating it if needed +func (app *SIPGuardianApp) GetOrCreateGuardian(ctx 
caddy.Context, name string, config *SIPGuardian) (*SIPGuardian, error) { + if name == "" { + name = "default" + } + + app.mu.Lock() + defer app.mu.Unlock() + + if g, exists := app.guardians[name]; exists { + // Guardian exists - merge any new config (additive only) + if config != nil { + mergeGuardianConfig(ctx, g, config) + } + return g, nil + } + + // Create new guardian with config + var g *SIPGuardian + if config != nil { + // Copy config values to a new guardian + g = &SIPGuardian{ + MaxFailures: config.MaxFailures, + FindTime: config.FindTime, + BanTime: config.BanTime, + WhitelistCIDR: config.WhitelistCIDR, + WhitelistHosts: config.WhitelistHosts, + WhitelistSRV: config.WhitelistSRV, + DNSRefresh: config.DNSRefresh, + Webhooks: config.Webhooks, + StoragePath: config.StoragePath, + GeoIPPath: config.GeoIPPath, + BlockedCountries: config.BlockedCountries, + AllowedCountries: config.AllowedCountries, + Enumeration: config.Enumeration, + Validation: config.Validation, + EnableMetrics: config.EnableMetrics, + EnableWebhooks: config.EnableWebhooks, + EnableStorage: config.EnableStorage, + } + } else { + g = &SIPGuardian{} + } + + if err := g.Provision(ctx); err != nil { + return nil, fmt.Errorf("failed to provision guardian: %w", err) + } + + app.guardians[name] = g + app.logger.Debug("Guardian created", zap.String("name", name)) + + return g, nil +} + +// GetGuardian returns an existing guardian instance (or nil if not found) +func (app *SIPGuardianApp) GetGuardian(name string) *SIPGuardian { + if name == "" { + name = "default" + } + + app.mu.RLock() + defer app.mu.RUnlock() + + return app.guardians[name] +} + +// ListGuardians returns all guardian names +func (app *SIPGuardianApp) ListGuardians() []string { + app.mu.RLock() + defer app.mu.RUnlock() + + names := make([]string, 0, len(app.guardians)) + for name := range app.guardians { + names = append(names, name) + } + return names +} + +// Interface guards +var ( + _ caddy.App = (*SIPGuardianApp)(nil) + _ 
caddy.Provisioner = (*SIPGuardianApp)(nil) +) diff --git a/l4handler.go b/l4handler.go index fd499f7..545ac4d 100644 --- a/l4handler.go +++ b/l4handler.go @@ -104,9 +104,19 @@ func (SIPHandler) CaddyModule() caddy.ModuleInfo { func (h *SIPHandler) Provision(ctx caddy.Context) error { h.logger = ctx.Logger() - // Get or create a shared guardian instance from the global registry - // Pass our parsed config so the guardian can be configured - guardian, err := GetOrCreateGuardianWithConfig(ctx, "default", &h.SIPGuardian) + // Get the SIP Guardian app from Caddy's app system (not global registry) + appIface, err := ctx.App("sip_guardian") + if err != nil { + return fmt.Errorf("failed to get sip_guardian app: %w", err) + } + + app, ok := appIface.(*SIPGuardianApp) + if !ok { + return fmt.Errorf("sip_guardian app has wrong type: %T", appIface) + } + + // Get or create guardian instance from the app + guardian, err := app.GetOrCreateGuardian(ctx, "default", &h.SIPGuardian) if err != nil { return err } @@ -126,7 +136,7 @@ func (h *SIPHandler) Handle(cx *layer4.Connection, next layer4.Handler) error { // Check if IP is banned if h.guardian.IsBanned(host) { h.logger.Debug("Blocked banned IP", zap.String("ip", host)) - if enableMetrics { + if h.guardian.metricsEnabled() { RecordConnection("blocked") } return cx.Close() @@ -134,7 +144,7 @@ func (h *SIPHandler) Handle(cx *layer4.Connection, next layer4.Handler) error { // Check if IP is whitelisted - skip further checks if h.guardian.IsWhitelisted(host) { - if enableMetrics { + if h.guardian.metricsEnabled() { RecordConnection("allowed") } return next.Handle(cx) @@ -146,7 +156,7 @@ func (h *SIPHandler) Handle(cx *layer4.Connection, next layer4.Handler) error { zap.String("ip", host), zap.String("country", country), ) - if enableMetrics { + if h.guardian.metricsEnabled() { RecordConnection("geo_blocked") } return cx.Close() @@ -164,7 +174,7 @@ func (h *SIPHandler) Handle(cx *layer4.Connection, next layer4.Handler) error { ) // 
Record message size metric - if enableMetrics { + if h.guardian.metricsEnabled() { RecordMessageSize(n) } @@ -174,7 +184,7 @@ func (h *SIPHandler) Handle(cx *layer4.Connection, next layer4.Handler) error { validationResult := validator.Validate(buf) // Record metrics for violations - if enableMetrics { + if h.guardian.metricsEnabled() { for _, v := range validationResult.Violations { RecordValidationViolation(v.Rule) } @@ -205,7 +215,7 @@ func (h *SIPHandler) Handle(cx *layer4.Connection, next layer4.Handler) error { } if validationResult.ShouldBan { - if enableMetrics { + if h.guardian.metricsEnabled() { RecordConnection("validation_blocked") } h.guardian.RecordFailure(host, validationResult.BanReason) @@ -228,7 +238,7 @@ func (h *SIPHandler) Handle(cx *layer4.Connection, next layer4.Handler) error { zap.String("ip", host), zap.String("method", string(method)), ) - if enableMetrics { + if h.guardian.metricsEnabled() { RecordConnection("rate_limited") } // Record as failure (may trigger ban) @@ -250,7 +260,7 @@ func (h *SIPHandler) Handle(cx *layer4.Connection, next layer4.Handler) error { zap.Strings("extensions", result.Extensions), ) - if enableMetrics { + if h.guardian.metricsEnabled() { RecordEnumerationDetection(result.Reason) RecordEnumerationExtensions(result.UniqueCount) RecordConnection("enumeration_blocked") @@ -262,7 +272,7 @@ func (h *SIPHandler) Handle(cx *layer4.Connection, next layer4.Handler) error { } // Emit webhook event - if enableWebhooks { + if h.guardian.webhooksEnabled() { go EmitEnumerationEvent(h.logger, host, result) } @@ -272,7 +282,7 @@ func (h *SIPHandler) Handle(cx *layer4.Connection, next layer4.Handler) error { } // Update metrics for tracked IPs - if enableMetrics { + if h.guardian.metricsEnabled() { stats := detector.GetStats() if trackedIPs, ok := stats["tracked_ips"].(int); ok { UpdateEnumerationTrackedIPs(trackedIPs) @@ -288,7 +298,7 @@ func (h *SIPHandler) Handle(cx *layer4.Connection, next layer4.Handler) error { 
zap.String("pattern", suspiciousPattern), zap.ByteString("sample", buf[:min(64, len(buf))]), ) - if enableMetrics { + if h.guardian.metricsEnabled() { RecordSuspiciousPattern(suspiciousPattern) RecordConnection("suspicious") } @@ -314,7 +324,7 @@ func (h *SIPHandler) Handle(cx *layer4.Connection, next layer4.Handler) error { } // Record successful connection - if enableMetrics { + if h.guardian.metricsEnabled() { RecordConnection("allowed") } diff --git a/metrics.go b/metrics.go index 3800570..1051645 100644 --- a/metrics.go +++ b/metrics.go @@ -1,6 +1,7 @@ package sipguardian import ( + "fmt" "net/http" "github.com/caddyserver/caddy/v2" @@ -151,17 +152,10 @@ var ( ) ) -// metricsRegistered tracks if we've registered with Prometheus -var metricsRegistered bool - // RegisterMetrics registers all SIP Guardian metrics with Prometheus +// It's safe to call multiple times - already registered metrics are silently ignored func RegisterMetrics() { - if metricsRegistered { - return - } - metricsRegistered = true - - prometheus.MustRegister( + collectors := []prometheus.Collector{ sipConnectionsTotal, sipBansTotal, sipUnbansTotal, @@ -177,7 +171,19 @@ func RegisterMetrics() { sipValidationViolations, sipValidationResults, sipMessageSizeBytes, - ) + } + + for _, collector := range collectors { + if err := prometheus.Register(collector); err != nil { + // Check if already registered - this is expected on config reload + if _, ok := err.(prometheus.AlreadyRegisteredError); !ok { + // Unexpected error - log it but don't panic + // Recording still works; this collector just won't be exported + caddy.Log().Named("sip_guardian").Sugar().Warnf("failed to register metric collector: %v", err) + } + // Already registered is fine - metrics are global and shared + } + } } // Metric recording functions - called from other modules @@ -263,6 +269,9 @@ func RecordMessageSize(bytes int) { type MetricsHandler struct { // Path prefix for metrics (default: /metrics) Path string 
`json:"path,omitempty"` + + // app is the SIP Guardian app instance (set during provision) + app *SIPGuardianApp 
} func (MetricsHandler) CaddyModule() caddy.ModuleInfo { @@ -279,19 +288,33 @@ func (h *MetricsHandler) Provision(ctx caddy.Context) error { h.Path = "/metrics" } + // Get the SIP Guardian app from Caddy's app system + appIface, err := ctx.App("sip_guardian") + if err != nil { + return fmt.Errorf("failed to get sip_guardian app: %w", err) + } + + app, ok := appIface.(*SIPGuardianApp) + if !ok { + return fmt.Errorf("sip_guardian app has wrong type: %T", appIface) + } + h.app = app + return nil } // ServeHTTP serves the Prometheus metrics func (h *MetricsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error { - // Update gauges from current state - if guardian := GetGuardian("default"); guardian != nil { - stats := guardian.GetStats() - if activeBans, ok := stats["active_bans"].(int); ok { - UpdateActiveBans(activeBans) - } - if trackedFailures, ok := stats["tracked_failures"].(int); ok { - UpdateTrackedIPs(trackedFailures) + // Update gauges from current state (use app, not global registry) + if h.app != nil { + if guardian := h.app.GetGuardian("default"); guardian != nil { + stats := guardian.GetStats() + if activeBans, ok := stats["active_bans"].(int); ok { + UpdateActiveBans(activeBans) + } + if trackedFailures, ok := stats["tracked_failures"].(int); ok { + UpdateTrackedIPs(trackedFailures) + } } } diff --git a/registry.go b/registry.go index 47912eb..0861c87 100644 --- a/registry.go +++ b/registry.go @@ -94,15 +94,25 @@ func mergeGuardianConfig(ctx caddy.Context, g *SIPGuardian, config *SIPGuardian) } } - // Override numeric values if they're non-zero (handler specified them) + // Config is immutable after provision - log warnings for attempted changes + // Changing these values would create race conditions with RecordFailure() if config.MaxFailures > 0 && config.MaxFailures != g.MaxFailures { - g.MaxFailures = config.MaxFailures + logger.Warn("Cannot change max_failures on running guardian (requires config reload)", + 
zap.Int("existing", g.MaxFailures), + zap.Int("attempted", config.MaxFailures), + ) } if config.FindTime > 0 && config.FindTime != g.FindTime { - g.FindTime = config.FindTime + logger.Warn("Cannot change find_time on running guardian (requires config reload)", + zap.Duration("existing", time.Duration(g.FindTime)), + zap.Duration("attempted", time.Duration(config.FindTime)), + ) } if config.BanTime > 0 && config.BanTime != g.BanTime { - g.BanTime = config.BanTime + logger.Warn("Cannot change ban_time on running guardian (requires config reload)", + zap.Duration("existing", time.Duration(g.BanTime)), + zap.Duration("attempted", time.Duration(config.BanTime)), + ) } // Initialize storage if specified and not yet initialized @@ -179,7 +189,7 @@ func mergeGuardianConfig(ctx caddy.Context, g *SIPGuardian, config *SIPGuardian) if !found { g.Webhooks = append(g.Webhooks, webhook) // Register with webhook manager - if enableWebhooks { + if g.webhooksEnabled() { wm := GetWebhookManager(logger) wm.AddWebhook(webhook) } diff --git a/sipguardian.go b/sipguardian.go index 31ef700..4998a22 100644 --- a/sipguardian.go +++ b/sipguardian.go @@ -14,13 +14,6 @@ import ( "go.uber.org/zap" ) -// Feature flags for optional components -var ( - enableMetrics = true - enableWebhooks = true - enableStorage = true -) - // Configuration limits to prevent unbounded growth under attack const ( maxTrackedIPs = 100000 // Max IPs to track failures for @@ -28,9 +21,8 @@ const ( cleanupBatchSize = 1000 // Max entries to clean per cycle ) -func init() { - caddy.RegisterModule(SIPGuardian{}) -} +// init() removed - SIPGuardian is no longer a standalone module +// It's now managed by SIPGuardianApp (see app.go) // BanEntry represents a banned IP with metadata type BanEntry struct { @@ -71,6 +63,12 @@ type SIPGuardian struct { // Validation configuration Validation *ValidationConfig `json:"validation,omitempty"` + // Feature toggles (configurable per instance, default: all enabled) + // Note: pointer 
fields mean nil = unset (default enabled); omitempty only drops nil, so an explicit false is preserved + EnableMetrics *bool `json:"enable_metrics,omitempty"` + EnableWebhooks *bool `json:"enable_webhooks,omitempty"` + EnableStorage *bool `json:"enable_storage,omitempty"` + // Runtime state logger *zap.Logger bannedIPs map[string]*BanEntry @@ -81,6 +79,9 @@ type SIPGuardian struct { storage *Storage geoIP *GeoIPLookup + // Storage worker pool (prevents goroutine explosion during DDoS) + storageWorkCh chan storageWork + // Lifecycle management stopCh chan struct{} wg sync.WaitGroup @@ -92,11 +93,74 @@ type failureTracker struct { lastSeen time.Time } -// CaddyModule returns the Caddy module information. -func (SIPGuardian) CaddyModule() caddy.ModuleInfo { - return caddy.ModuleInfo{ - ID: "sip_guardian", - New: func() caddy.Module { return new(SIPGuardian) }, +// storageWork represents a storage operation to be performed by worker pool +type storageWork struct { + op string // "record_failure", "save_ban", "remove_ban" + ip string + data interface{} // *BanEntry for save_ban, reason string for others +} + +// CaddyModule removed - SIPGuardian is no longer a standalone module +// It's now managed by SIPGuardianApp which implements caddy.App + +// Helper methods for feature flags (default to true if not set) +func (g *SIPGuardian) metricsEnabled() bool { + return g.EnableMetrics == nil || *g.EnableMetrics +} + +func (g *SIPGuardian) webhooksEnabled() bool { + return g.EnableWebhooks == nil || *g.EnableWebhooks +} + +func (g *SIPGuardian) storageEnabled() bool { + return g.EnableStorage == nil || *g.EnableStorage +} + +// storageWorker processes storage operations from the work channel +// Runs in dedicated goroutine as part of worker pool +func (g *SIPGuardian) storageWorker(id int) { + defer g.wg.Done() + + g.logger.Debug("Storage worker started", zap.Int("worker_id", id)) + + for { + select { + case <-g.stopCh: + g.logger.Debug("Storage worker stopping", 
zap.Int("worker_id", id)) + return + + case work 
:= <-g.storageWorkCh: + // Process the storage operation + switch work.op { + case "record_failure": + if reason, ok := work.data.(string); ok { + g.storage.RecordFailure(work.ip, reason, nil) + } + + case "save_ban": + if entry, ok := work.data.(*BanEntry); ok { + if err := g.storage.SaveBan(entry); err != nil { + g.logger.Error("Failed to save ban to storage", + zap.Error(err), + zap.String("ip", entry.IP), + ) + } + } + + case "remove_ban": + if reason, ok := work.data.(string); ok { + if err := g.storage.RemoveBan(work.ip, reason); err != nil { + g.logger.Error("Failed to remove ban from storage", + zap.Error(err), + zap.String("ip", work.ip), + ) + } + } + + default: + g.logger.Warn("Unknown storage operation", zap.String("op", work.op)) + } + } } } @@ -128,12 +192,12 @@ func (g *SIPGuardian) Provision(ctx caddy.Context) error { } // Initialize metrics - if enableMetrics { + if g.metricsEnabled() { RegisterMetrics() } // Initialize webhooks - if enableWebhooks && len(g.Webhooks) > 0 { + if g.webhooksEnabled() && len(g.Webhooks) > 0 { wm := GetWebhookManager(g.logger) for _, config := range g.Webhooks { wm.AddWebhook(config) @@ -141,7 +205,7 @@ func (g *SIPGuardian) Provision(ctx caddy.Context) error { } // Initialize persistent storage - if enableStorage && g.StoragePath != "" { + if g.storageEnabled() && g.StoragePath != "" { storage, err := InitStorage(g.logger, StorageConfig{ Path: g.StoragePath, }) @@ -155,6 +219,15 @@ func (g *SIPGuardian) Provision(ctx caddy.Context) error { if err := g.loadBansFromStorage(); err != nil { g.logger.Warn("Failed to load bans from storage", zap.Error(err)) } + + // Start storage worker pool (4 workers, 1000 buffered operations) + // This prevents goroutine explosion during DDoS attacks + g.storageWorkCh = make(chan storageWork, 1000) + for i := 0; i < 4; i++ { + g.wg.Add(1) + go g.storageWorker(i) + } + g.logger.Debug("Storage worker pool started", zap.Int("workers", 4)) } } @@ -274,7 +347,7 @@ func (g *SIPGuardian) 
IsWhitelisted(ip string) bool { // Check CIDR-based whitelist for _, network := range g.whitelistNets { if network.Contains(parsedIP) { - if enableMetrics { + if g.metricsEnabled() { RecordWhitelistedConnection() } return true @@ -283,7 +356,7 @@ func (g *SIPGuardian) IsWhitelisted(ip string) bool { // Check DNS-based whitelist if g.dnsWhitelist != nil && g.dnsWhitelist.Contains(ip) { - if enableMetrics { + if g.metricsEnabled() { RecordWhitelistedConnection() } g.logger.Debug("IP whitelisted via DNS", @@ -388,20 +461,26 @@ func (g *SIPGuardian) RecordFailure(ip, reason string) bool { ) // Record metrics - if enableMetrics { + if g.metricsEnabled() { RecordFailure(reason) UpdateTrackedIPs(len(g.failureCounts)) } - // Record in storage (async) - if g.storage != nil { - go func() { - g.storage.RecordFailure(ip, reason, nil) - }() + // Record in storage (async via worker pool) + if g.storage != nil && g.storageWorkCh != nil { + select { + case g.storageWorkCh <- storageWork{op: "record_failure", ip: ip, data: reason}: + // Work queued successfully + default: + // Channel full - drop the write (fail-fast during attack) + g.logger.Warn("Storage work queue full, dropping failure record", + zap.String("ip", ip), + ) + } } // Emit failure event via webhook - if enableWebhooks { + if g.webhooksEnabled() { EmitFailureEvent(g.logger, ip, reason, tracker.count) } @@ -456,21 +535,25 @@ func (g *SIPGuardian) banIP(ip, reason string) { ) // Record metrics - if enableMetrics { + if g.metricsEnabled() { RecordBan() } - // Save to persistent storage - if g.storage != nil { - go func() { - if err := g.storage.SaveBan(entry); err != nil { - g.logger.Error("Failed to save ban to storage", zap.Error(err)) - } - }() + // Save to persistent storage (async via worker pool) + if g.storage != nil && g.storageWorkCh != nil { + select { + case g.storageWorkCh <- storageWork{op: "save_ban", ip: ip, data: entry}: + // Work queued successfully + default: + // Channel full - drop the write 
(fail-fast during attack) + g.logger.Warn("Storage work queue full, dropping ban save", + zap.String("ip", ip), + ) + } } // Emit webhook event - if enableWebhooks { + if g.webhooksEnabled() { EmitBanEvent(g.logger, entry) } } @@ -482,7 +565,7 @@ func (g *SIPGuardian) UnbanIP(ip string) bool { if entry, exists := g.bannedIPs[ip]; exists { // Record ban duration for metrics - if enableMetrics { + if g.metricsEnabled() { duration := time.Since(entry.BannedAt).Seconds() RecordBanDuration(duration) RecordUnban() @@ -491,17 +574,21 @@ func (g *SIPGuardian) UnbanIP(ip string) bool { delete(g.bannedIPs, ip) g.logger.Info("IP unbanned", zap.String("ip", ip)) - // Update storage - if g.storage != nil { - go func() { - if err := g.storage.RemoveBan(ip, "manual_unban"); err != nil { - g.logger.Error("Failed to update storage on unban", zap.Error(err)) - } - }() + // Update storage (async via worker pool) + if g.storage != nil && g.storageWorkCh != nil { + select { + case g.storageWorkCh <- storageWork{op: "remove_ban", ip: ip, data: "manual_unban"}: + // Work queued successfully + default: + // Channel full - drop the write (fail-fast during attack) + g.logger.Warn("Storage work queue full, dropping ban removal", + zap.String("ip", ip), + ) + } } // Emit webhook event - if enableWebhooks { + if g.webhooksEnabled() { EmitUnbanEvent(g.logger, ip, "manual_unban") } @@ -880,6 +967,12 @@ func (g *SIPGuardian) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { func (g *SIPGuardian) Cleanup() error { g.logger.Info("SIP Guardian cleanup starting") + // Do NOT close storageWorkCh: RecordFailure/banIP/UnbanIP may still send, and a send on a closed channel panics + // Workers exit via stopCh below; work still queued at that point may be dropped (best effort) + if g.storageWorkCh != nil { + g.logger.Debug("Storage workers stopping", zap.Int("queued", len(g.storageWorkCh))) + } + // Signal all goroutines to stop close(g.stopCh) @@ -979,7 +1072,6 @@ func (g *SIPGuardian) BanIP(ip, reason 
string) { // Interface guards var ( - _ caddy.Module = (*SIPGuardian)(nil) _ caddy.Provisioner = (*SIPGuardian)(nil) _ caddy.CleanerUpper = (*SIPGuardian)(nil) _ 
caddy.Validator = (*SIPGuardian)(nil)