package sipguardian import ( "net" "sync" "time" "github.com/caddyserver/caddy/v2" "go.uber.org/zap" ) // Global registry to share guardian instances across modules var ( guardianRegistry = make(map[string]*SIPGuardian) registryMu sync.RWMutex ) // GetOrCreateGuardian returns a shared guardian instance by name (backward compat) func GetOrCreateGuardian(ctx caddy.Context, name string) (*SIPGuardian, error) { return GetOrCreateGuardianWithConfig(ctx, name, nil) } // GetOrCreateGuardianWithConfig returns a shared guardian instance, merging config if provided func GetOrCreateGuardianWithConfig(ctx caddy.Context, name string, config *SIPGuardian) (*SIPGuardian, error) { if name == "" { name = "default" } registryMu.Lock() defer registryMu.Unlock() if g, exists := guardianRegistry[name]; exists { // Guardian exists - merge any new config if config != nil { mergeGuardianConfig(ctx, g, config) } return g, nil } // Create new guardian with config var g *SIPGuardian if config != nil { // Copy config values to a new guardian g = &SIPGuardian{ MaxFailures: config.MaxFailures, FindTime: config.FindTime, BanTime: config.BanTime, WhitelistCIDR: config.WhitelistCIDR, Webhooks: config.Webhooks, StoragePath: config.StoragePath, GeoIPPath: config.GeoIPPath, BlockedCountries: config.BlockedCountries, AllowedCountries: config.AllowedCountries, Enumeration: config.Enumeration, } } else { g = &SIPGuardian{} } if err := g.Provision(ctx); err != nil { return nil, err } guardianRegistry[name] = g return g, nil } // mergeGuardianConfig merges new config into an existing guardian // This handles cases where multiple handlers might specify overlapping config func mergeGuardianConfig(ctx caddy.Context, g *SIPGuardian, config *SIPGuardian) { g.mu.Lock() defer g.mu.Unlock() logger := ctx.Logger() // Merge whitelist CIDRs (add new ones, avoid duplicates) for _, cidr := range config.WhitelistCIDR { found := false for _, existing := range g.WhitelistCIDR { if existing == cidr { found = true break } } if !found { g.WhitelistCIDR = append(g.WhitelistCIDR, cidr) // Parse and add to whitelistNets if _, network, err := net.ParseCIDR(cidr); err == nil { g.whitelistNets = append(g.whitelistNets, network) logger.Debug("Added whitelist CIDR from handler config", zap.String("cidr", cidr), ) } } } // Override numeric values if they're non-zero (handler specified them) if config.MaxFailures > 0 && config.MaxFailures != g.MaxFailures { g.MaxFailures = config.MaxFailures } if config.FindTime > 0 && config.FindTime != g.FindTime { g.FindTime = config.FindTime } if config.BanTime > 0 && config.BanTime != g.BanTime { g.BanTime = config.BanTime } // Initialize storage if specified and not yet initialized if config.StoragePath != "" && g.storage == nil { storage, err := InitStorage(logger, StorageConfig{ Path: config.StoragePath, }) if err != nil { logger.Warn("Failed to initialize storage from handler config", zap.Error(err), ) } else { g.storage = storage g.StoragePath = config.StoragePath // Load existing bans from storage if bans, err := storage.LoadActiveBans(); err == nil { for _, ban := range bans { entry := ban g.bannedIPs[entry.IP] = &entry } logger.Info("Loaded bans from storage", zap.Int("count", len(bans))) } } } // Initialize GeoIP if specified and not yet initialized if config.GeoIPPath != "" && g.geoIP == nil { geoIP, err := NewGeoIPLookup(config.GeoIPPath) if err != nil { logger.Warn("Failed to initialize GeoIP from handler config", zap.Error(err), ) } else { g.geoIP = geoIP g.GeoIPPath = config.GeoIPPath } } // Merge blocked/allowed countries for _, country := range config.BlockedCountries { found := false for _, existing := range g.BlockedCountries { if existing == country { found = true break } } if !found { g.BlockedCountries = append(g.BlockedCountries, country) } } for _, country := range config.AllowedCountries { found := false for _, existing := range g.AllowedCountries { if existing == country { found = true break } } if !found { g.AllowedCountries = append(g.AllowedCountries, country) } } // Merge webhooks (add new ones by URL) for _, webhook := range config.Webhooks { found := false for _, existing := range g.Webhooks { if existing.URL == webhook.URL { found = true break } } if !found { g.Webhooks = append(g.Webhooks, webhook) // Register with webhook manager if enableWebhooks { wm := GetWebhookManager(logger) wm.AddWebhook(webhook) } } } // Apply enumeration config if specified if config.Enumeration != nil && g.Enumeration == nil { g.Enumeration = config.Enumeration // Apply to global detector SetEnumerationConfig(*config.Enumeration) logger.Debug("Applied enumeration config from handler") } logger.Debug("Merged guardian config", zap.Int("whitelist_count", len(g.whitelistNets)), zap.Int("webhook_count", len(g.Webhooks)), zap.Duration("ban_time", time.Duration(g.BanTime)), ) } // GetGuardian returns an existing guardian instance func GetGuardian(name string) *SIPGuardian { if name == "" { name = "default" } registryMu.RLock() defer registryMu.RUnlock() return guardianRegistry[name] } // ListGuardians returns all guardian names func ListGuardians() []string { registryMu.RLock() defer registryMu.RUnlock() names := make([]string, 0, len(guardianRegistry)) for name := range guardianRegistry { names = append(names, name) } return names }