From f76946fc41af4243a2c2f47c4929ea56089d6c24 Mon Sep 17 00:00:00 2001 From: Ryan Malloy Date: Sun, 7 Dec 2025 19:02:50 -0700 Subject: [PATCH] Add SIP topology hiding feature (B2BUA-lite) Implements RFC 3261 compliant topology hiding to protect internal infrastructure from external attackers: New files: - sipmsg.go: SIP message parsing/serialization with full header support - sipheaders.go: Via, Contact, From/To header parsing with compact forms - dialog_state.go: Dialog and transaction state management for response correlation - topology.go: TopologyHider handler for caddy-l4 integration - topology_test.go: Comprehensive unit tests (26 new tests, 60 total) Features: - Via header insertion (proxy adds own Via, pops on response) - Contact header rewriting (hide internal IPs behind proxy address) - Sensitive header stripping (P-Asserted-Identity, Server, etc.) - Call-ID anonymization (optional) - Private IP masking in all headers - Dialog state tracking for stateful response routing - Transaction state for stateless operation Caddyfile configuration: sip_topology_hider { proxy_host 203.0.113.1 proxy_port 5060 upstream udp/192.168.1.100:5060 rewrite_via rewrite_contact strip_headers P-Preferred-Identity Server } --- dialog_state.go | 391 ++++++++++++++++++++++ sipheaders.go | 466 ++++++++++++++++++++++++++ sipmsg.go | 416 +++++++++++++++++++++++ topology.go | 539 ++++++++++++++++++++++++++++++ topology_test.go | 833 +++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 2645 insertions(+) create mode 100644 dialog_state.go create mode 100644 sipheaders.go create mode 100644 sipmsg.go create mode 100644 topology.go create mode 100644 topology_test.go diff --git a/dialog_state.go b/dialog_state.go new file mode 100644 index 0000000..f49aa15 --- /dev/null +++ b/dialog_state.go @@ -0,0 +1,391 @@ +package sipguardian + +import ( + "sync" + "time" + + "go.uber.org/zap" +) + +// DialogState tracks the state of a SIP dialog for topology hiding +type DialogState struct { + // Dialog identifiers + CallID string + FromTag string + ToTag string + LocalBranch string // Via branch we added + + // Original values (for response rewriting) + OriginalCallID string // if anonymized + OriginalVia string // original top Via header + OriginalContact string // original Contact header + + // Proxy values (what we replaced with) + ProxyContact string + ProxyVia string + + // Client information (for routing responses) + ClientHost string + ClientPort int + + // Timestamps + Created time.Time + LastActivity time.Time + + // State tracking + IsConfirmed bool // 2xx received for dialog-creating request + IsTerminated bool +} + +// DialogManager manages dialog state for topology hiding +type DialogManager struct { + dialogs map[string]*DialogState + byBranch map[string]*DialogState // Quick lookup by Via branch + cleanupTTL time.Duration + logger *zap.Logger + mu sync.RWMutex +} + +// NewDialogManager creates a new dialog manager +func NewDialogManager(logger *zap.Logger, cleanupTTL time.Duration) *DialogManager { + if cleanupTTL == 0 { + cleanupTTL = 10 * time.Minute + } + + dm := &DialogManager{ + dialogs: make(map[string]*DialogState), + byBranch: make(map[string]*DialogState), + cleanupTTL: cleanupTTL, + logger: logger, + } + + return dm +} + +// dialogKey generates a unique key for a dialog +func dialogKey(callID, fromTag string) string { + return callID + ":" + fromTag +} + +// CreateDialog creates a new dialog state for an outgoing request +func (dm *DialogManager) CreateDialog(msg *SIPMessage, clientHost string, clientPort int) *DialogState { + dm.mu.Lock() + defer dm.mu.Unlock() + + callID := msg.GetCallID() + fromTag := msg.GetFromTag() + branch := msg.GetViaBranch() + + state := &DialogState{ + CallID: callID, + FromTag: fromTag, + LocalBranch: branch, + ClientHost: clientHost, + ClientPort: clientPort, + Created: time.Now(), + LastActivity: time.Now(), + } + + // Store original Via if we're going to modify it + if via := msg.GetHeader("Via"); via != nil { + state.OriginalVia = via.Value + } + + // Store original Contact + if contact := msg.GetHeader("Contact"); contact != nil { + state.OriginalContact = contact.Value + } + + // Store by dialog key and by branch + key := dialogKey(callID, fromTag) + dm.dialogs[key] = state + + if branch != "" { + dm.byBranch[branch] = state + } + + dm.logger.Debug("Dialog created", + zap.String("call_id", callID), + zap.String("from_tag", fromTag), + zap.String("branch", branch), + ) + + return state +} + +// GetDialogByCallID retrieves a dialog by Call-ID and From-tag +func (dm *DialogManager) GetDialogByCallID(callID, fromTag string) *DialogState { + dm.mu.RLock() + defer dm.mu.RUnlock() + + key := dialogKey(callID, fromTag) + return dm.dialogs[key] +} + +// GetDialogByBranch retrieves a dialog by Via branch (for response routing) +func (dm *DialogManager) GetDialogByBranch(branch string) *DialogState { + dm.mu.RLock() + defer dm.mu.RUnlock() + + return dm.byBranch[branch] +} + +// UpdateDialog updates an existing dialog with new information +func (dm *DialogManager) UpdateDialog(callID, fromTag string, toTag string) { + dm.mu.Lock() + defer dm.mu.Unlock() + + key := dialogKey(callID, fromTag) + if state, ok := dm.dialogs[key]; ok { + state.LastActivity = time.Now() + if toTag != "" && state.ToTag == "" { + state.ToTag = toTag + } + } +} + +// ConfirmDialog marks a dialog as confirmed (2xx received) +func (dm *DialogManager) ConfirmDialog(callID, fromTag string) { + dm.mu.Lock() + defer dm.mu.Unlock() + + key := dialogKey(callID, fromTag) + if state, ok := dm.dialogs[key]; ok { + state.IsConfirmed = true + state.LastActivity = time.Now() + } +} + +// TerminateDialog marks a dialog as terminated +func (dm *DialogManager) TerminateDialog(callID, fromTag string) { + dm.mu.Lock() + defer dm.mu.Unlock() + + key := dialogKey(callID, fromTag) + if state, ok := dm.dialogs[key]; ok { + state.IsTerminated = true + state.LastActivity = time.Now() + } +} + +// RemoveDialog removes a dialog from the manager +func (dm *DialogManager) RemoveDialog(callID, fromTag string) { + dm.mu.Lock() + defer dm.mu.Unlock() + + key := dialogKey(callID, fromTag) + if state, ok := dm.dialogs[key]; ok { + // Remove from branch index + if state.LocalBranch != "" { + delete(dm.byBranch, state.LocalBranch) + } + delete(dm.dialogs, key) + + dm.logger.Debug("Dialog removed", + zap.String("call_id", callID), + zap.String("from_tag", fromTag), + ) + } +} + +// StoreOriginals stores original header values before rewriting +func (dm *DialogManager) StoreOriginals(callID, fromTag string, via, contact string, originalCallID string) { + dm.mu.Lock() + defer dm.mu.Unlock() + + key := dialogKey(callID, fromTag) + if state, ok := dm.dialogs[key]; ok { + if via != "" { + state.OriginalVia = via + } + if contact != "" { + state.OriginalContact = contact + } + if originalCallID != "" { + state.OriginalCallID = originalCallID + } + } +} + +// StoreProxyValues stores the values we used to replace originals +func (dm *DialogManager) StoreProxyValues(callID, fromTag string, via, contact string) { + dm.mu.Lock() + defer dm.mu.Unlock() + + key := dialogKey(callID, fromTag) + if state, ok := dm.dialogs[key]; ok { + if via != "" { + state.ProxyVia = via + } + if contact != "" { + state.ProxyContact = contact + } + } +} + +// Cleanup removes stale dialogs +func (dm *DialogManager) Cleanup() int { + dm.mu.Lock() + defer dm.mu.Unlock() + + cutoff := time.Now().Add(-dm.cleanupTTL) + removed := 0 + + for key, state := range dm.dialogs { + // Remove terminated dialogs after TTL + if state.IsTerminated && state.LastActivity.Before(cutoff) { + if state.LocalBranch != "" { + delete(dm.byBranch, state.LocalBranch) + } + delete(dm.dialogs, key) + removed++ + continue + } + + // Remove unconfirmed dialogs that have been idle too long + if !state.IsConfirmed && state.LastActivity.Before(cutoff) { + if state.LocalBranch != "" { + delete(dm.byBranch, state.LocalBranch) + } + delete(dm.dialogs, key) + removed++ + continue + } + + // Remove very old dialogs regardless of state (prevent memory leak) + if state.Created.Before(time.Now().Add(-24 * time.Hour)) { + if state.LocalBranch != "" { + delete(dm.byBranch, state.LocalBranch) + } + delete(dm.dialogs, key) + removed++ + } + } + + if removed > 0 { + dm.logger.Debug("Cleaned up stale dialogs", zap.Int("count", removed)) + } + + return removed +} + +// GetStats returns statistics about the dialog manager +func (dm *DialogManager) GetStats() map[string]interface{} { + dm.mu.RLock() + defer dm.mu.RUnlock() + + confirmed := 0 + terminated := 0 + pending := 0 + + for _, state := range dm.dialogs { + if state.IsTerminated { + terminated++ + } else if state.IsConfirmed { + confirmed++ + } else { + pending++ + } + } + + return map[string]interface{}{ + "total_dialogs": len(dm.dialogs), + "confirmed_dialogs": confirmed, + "terminated_dialogs": terminated, + "pending_dialogs": pending, + "branch_index_size": len(dm.byBranch), + } +} + +// TransactionState tracks a single SIP transaction (simpler than full dialog) +type TransactionState struct { + Branch string + Method string + ClientHost string + ClientPort int + OriginalVia string + Created time.Time +} + +// TransactionManager manages transaction state for stateless topology hiding +type TransactionManager struct { + transactions map[string]*TransactionState + cleanupTTL time.Duration + logger *zap.Logger + mu sync.RWMutex +} + +// NewTransactionManager creates a new transaction manager +func NewTransactionManager(logger *zap.Logger, cleanupTTL time.Duration) *TransactionManager { + if cleanupTTL == 0 { + cleanupTTL = 32 * time.Second // SIP transaction timeout + } + + return &TransactionManager{ + transactions: make(map[string]*TransactionState), + cleanupTTL: cleanupTTL, + logger: logger, + } +} + +// CreateTransaction creates a new transaction state +func (tm *TransactionManager) CreateTransaction(branch, method, clientHost string, clientPort int, originalVia string) *TransactionState { + tm.mu.Lock() + defer tm.mu.Unlock() + + state := &TransactionState{ + Branch: branch, + Method: method, + ClientHost: clientHost, + ClientPort: clientPort, + OriginalVia: originalVia, + Created: time.Now(), + } + + tm.transactions[branch] = state + return state +} + +// GetTransaction retrieves a transaction by branch +func (tm *TransactionManager) GetTransaction(branch string) *TransactionState { + tm.mu.RLock() + defer tm.mu.RUnlock() + + return tm.transactions[branch] +} + +// RemoveTransaction removes a transaction +func (tm *TransactionManager) RemoveTransaction(branch string) { + tm.mu.Lock() + defer tm.mu.Unlock() + + delete(tm.transactions, branch) +} + +// Cleanup removes expired transactions +func (tm *TransactionManager) Cleanup() int { + tm.mu.Lock() + defer tm.mu.Unlock() + + cutoff := time.Now().Add(-tm.cleanupTTL) + removed := 0 + + for branch, state := range tm.transactions { + if state.Created.Before(cutoff) { + delete(tm.transactions, branch) + removed++ + } + } + + return removed +} + +// GetStats returns statistics about the transaction manager +func (tm *TransactionManager) GetStats() map[string]interface{} { + tm.mu.RLock() + defer tm.mu.RUnlock() + + return map[string]interface{}{ + "active_transactions": len(tm.transactions), + } +} diff --git a/sipheaders.go b/sipheaders.go new file mode 100644 index 0000000..adba222 --- /dev/null +++ b/sipheaders.go @@ -0,0 +1,466 @@ +package sipguardian + +import ( + "fmt" + "regexp" + "strconv" + "strings" +) + +// ViaHeader represents a parsed Via header +type ViaHeader struct { + Protocol string // SIP/2.0 + Transport string // UDP, TCP, TLS, WS, WSS + Host string + Port int + Branch string + Received string + RPort int + Params map[string]string // Other parameters +} + +// ContactHeader represents a parsed Contact header +type ContactHeader struct { + DisplayName string + URI string + Params map[string]string +} + +// FromToHeader represents a parsed From or To header +type FromToHeader struct { + DisplayName string + URI string + Tag string + Params map[string]string +} + +// SensitiveHeaders lists headers that may leak internal topology +var SensitiveHeaders = []string{ + "P-Preferred-Identity", + "P-Asserted-Identity", + "Remote-Party-ID", + "P-Charging-Vector", + "P-Charging-Function-Addresses", + "Server", + "User-Agent", + "X-Asterisk-HangupCause", + "X-Asterisk-HangupCauseCode", +} + +// ParseViaHeader parses a Via header value into structured form +// Example: "SIP/2.0/UDP 192.168.1.1:5060;branch=z9hG4bK123;received=10.0.0.1" +func ParseViaHeader(value string) (*ViaHeader, error) { + via := &ViaHeader{ + Params: make(map[string]string), + } + + // Split by semicolons to get parameters + parts := strings.Split(value, ";") + if len(parts) == 0 { + return nil, fmt.Errorf("empty Via header") + } + + // First part: SIP/2.0/UDP host:port + mainPart := strings.TrimSpace(parts[0]) + fields := strings.Fields(mainPart) + if len(fields) < 2 { + return nil, fmt.Errorf("invalid Via header format") + } + + // Parse protocol/transport + protoTransport := strings.Split(fields[0], "/") + if len(protoTransport) >= 2 { + via.Protocol = protoTransport[0] + "/" + protoTransport[1] + } + if len(protoTransport) >= 3 { + via.Transport = protoTransport[2] + } + + // Parse host:port + hostPort := fields[1] + if colonIdx := strings.LastIndex(hostPort, ":"); colonIdx > 0 { + via.Host = hostPort[:colonIdx] + port, _ := strconv.Atoi(hostPort[colonIdx+1:]) + via.Port = port + } else { + via.Host = hostPort + via.Port = 5060 // Default SIP port + } + + // Parse parameters + for i := 1; i < len(parts); i++ { + param := strings.TrimSpace(parts[i]) + if eqIdx := strings.Index(param, "="); eqIdx > 0 { + key := strings.ToLower(param[:eqIdx]) + val := param[eqIdx+1:] + + switch key { + case "branch": + via.Branch = val + case "received": + via.Received = val + case "rport": + via.RPort, _ = strconv.Atoi(val) + default: + via.Params[key] = val + } + } else { + // Flag parameter (no value) + via.Params[strings.ToLower(param)] = "" + } + } + + return via, nil +} + +// Serialize converts ViaHeader back to string format +func (v *ViaHeader) Serialize() string { + var buf strings.Builder + + // Protocol and transport + buf.WriteString(v.Protocol) + buf.WriteByte('/') + buf.WriteString(v.Transport) + buf.WriteByte(' ') + + // Host and port + buf.WriteString(v.Host) + if v.Port > 0 && v.Port != 5060 { + buf.WriteByte(':') + buf.WriteString(strconv.Itoa(v.Port)) + } + + // Branch (required for RFC 3261 compliance) + if v.Branch != "" { + buf.WriteString(";branch=") + buf.WriteString(v.Branch) + } + + // Received parameter + if v.Received != "" { + buf.WriteString(";received=") + buf.WriteString(v.Received) + } + + // RPort parameter + if v.RPort > 0 { + buf.WriteString(";rport=") + buf.WriteString(strconv.Itoa(v.RPort)) + } + + // Other parameters + for key, val := range v.Params { + buf.WriteByte(';') + buf.WriteString(key) + if val != "" { + buf.WriteByte('=') + buf.WriteString(val) + } + } + + return buf.String() +} + +// ParseContactHeader parses a Contact header value +// Example: "\"Alice\" ;expires=3600" +func ParseContactHeader(value string) (*ContactHeader, error) { + contact := &ContactHeader{ + Params: make(map[string]string), + } + + // Handle * (special Contact for REGISTER) + if strings.TrimSpace(value) == "*" { + contact.URI = "*" + return contact, nil + } + + // Extract display name if present (quoted) + if idx := strings.Index(value, "\""); idx >= 0 { + endIdx := strings.Index(value[idx+1:], "\"") + if endIdx > 0 { + contact.DisplayName = value[idx+1 : idx+1+endIdx] + value = value[idx+1+endIdx+1:] + } + } + + // Extract URI (within angle brackets or until semicolon) + value = strings.TrimSpace(value) + if strings.HasPrefix(value, "<") { + endIdx := strings.Index(value, ">") + if endIdx > 0 { + contact.URI = value[1:endIdx] + value = value[endIdx+1:] + } + } else { + // No angle brackets - URI goes until semicolon or end + semiIdx := strings.Index(value, ";") + if semiIdx > 0 { + contact.URI = strings.TrimSpace(value[:semiIdx]) + value = value[semiIdx:] + } else { + contact.URI = strings.TrimSpace(value) + value = "" + } + } + + // Parse parameters + for _, part := range strings.Split(value, ";") { + part = strings.TrimSpace(part) + if part == "" { + continue + } + if eqIdx := strings.Index(part, "="); eqIdx > 0 { + contact.Params[strings.ToLower(part[:eqIdx])] = part[eqIdx+1:] + } else if part != "" { + contact.Params[strings.ToLower(part)] = "" + } + } + + return contact, nil +} + +// Serialize converts ContactHeader back to string format +func (c *ContactHeader) Serialize() string { + var buf strings.Builder + + if c.URI == "*" { + return "*" + } + + if c.DisplayName != "" { + buf.WriteByte('"') + buf.WriteString(c.DisplayName) + buf.WriteString("\" ") + } + + buf.WriteByte('<') + buf.WriteString(c.URI) + buf.WriteByte('>') + + for key, val := range c.Params { + buf.WriteByte(';') + buf.WriteString(key) + if val != "" { + buf.WriteByte('=') + buf.WriteString(val) + } + } + + return buf.String() +} + +// ParseFromToHeader parses a From or To header value +// Example: "\"Bob\" ;tag=abc123" +func ParseFromToHeader(value string) (*FromToHeader, error) { + header := &FromToHeader{ + Params: make(map[string]string), + } + + // Extract display name if present (quoted) + if idx := strings.Index(value, "\""); idx >= 0 { + endIdx := strings.Index(value[idx+1:], "\"") + if endIdx > 0 { + header.DisplayName = value[idx+1 : idx+1+endIdx] + value = value[idx+1+endIdx+1:] + } + } + + // Extract URI (within angle brackets) + value = strings.TrimSpace(value) + if strings.HasPrefix(value, "<") { + endIdx := strings.Index(value, ">") + if endIdx > 0 { + header.URI = value[1:endIdx] + value = value[endIdx+1:] + } + } else { + // No angle brackets + semiIdx := strings.Index(value, ";") + if semiIdx > 0 { + header.URI = strings.TrimSpace(value[:semiIdx]) + value = value[semiIdx:] + } else { + header.URI = strings.TrimSpace(value) + value = "" + } + } + + // Parse parameters + for _, part := range strings.Split(value, ";") { + part = strings.TrimSpace(part) + if part == "" { + continue + } + if eqIdx := strings.Index(part, "="); eqIdx > 0 { + key := strings.ToLower(part[:eqIdx]) + val := part[eqIdx+1:] + if key == "tag" { + header.Tag = val + } else { + header.Params[key] = val + } + } + } + + return header, nil +} + +// Serialize converts FromToHeader back to string format +func (f *FromToHeader) Serialize() string { + var buf strings.Builder + + if f.DisplayName != "" { + buf.WriteByte('"') + buf.WriteString(f.DisplayName) + buf.WriteString("\" ") + } + + buf.WriteByte('<') + buf.WriteString(f.URI) + buf.WriteByte('>') + + if f.Tag != "" { + buf.WriteString(";tag=") + buf.WriteString(f.Tag) + } + + for key, val := range f.Params { + buf.WriteByte(';') + buf.WriteString(key) + if val != "" { + buf.WriteByte('=') + buf.WriteString(val) + } + } + + return buf.String() +} + +// ExtractHostFromURI extracts the host part from a SIP URI +// Example: "sip:user@host:5060;transport=udp" -> "host" +func ExtractHostFromURI(uri string) string { + // Remove scheme + if idx := strings.Index(uri, ":"); idx >= 0 { + uri = uri[idx+1:] + } + + // Remove user part + if idx := strings.Index(uri, "@"); idx >= 0 { + uri = uri[idx+1:] + } + + // Remove port and parameters + if idx := strings.Index(uri, ":"); idx >= 0 { + uri = uri[:idx] + } + if idx := strings.Index(uri, ";"); idx >= 0 { + uri = uri[:idx] + } + + return uri +} + +// ExtractPortFromURI extracts the port from a SIP URI (default 5060) +func ExtractPortFromURI(uri string) int { + // Remove scheme and user part + if idx := strings.Index(uri, "@"); idx >= 0 { + uri = uri[idx+1:] + } else if idx := strings.Index(uri, ":"); idx >= 0 { + uri = uri[idx+1:] + } + + // Extract port + if colonIdx := strings.Index(uri, ":"); colonIdx >= 0 { + portStr := uri[colonIdx+1:] + if semiIdx := strings.Index(portStr, ";"); semiIdx >= 0 { + portStr = portStr[:semiIdx] + } + port, err := strconv.Atoi(portStr) + if err == nil { + return port + } + } + + return 5060 // Default SIP port +} + +// RewriteURIHost replaces the host in a SIP URI +func RewriteURIHost(uri, newHost string, newPort int) string { + // Pattern to match SIP URI + re := regexp.MustCompile(`^(sips?:)([^@]+@)?([^:;>]+)(:\d+)?(.*)$`) + matches := re.FindStringSubmatch(uri) + if matches == nil { + return uri + } + + // Rebuild URI with new host + result := matches[1] + matches[2] + newHost + if newPort > 0 && newPort != 5060 { + result += ":" + strconv.Itoa(newPort) + } + result += matches[5] + + return result +} + +// GenerateBranch generates an RFC 3261 compliant branch parameter +func GenerateBranch() string { + // Must start with z9hG4bK for RFC 3261 compliance + return "z9hG4bK" + generateRandomHex(16) +} + +// generateRandomHex generates a random hex string of given length +func generateRandomHex(length int) string { + const chars = "0123456789abcdef" + result := make([]byte, length) + for i := range result { + // Use uint64 modulo to avoid negative index from int conversion + result[i] = chars[pseudoRand()%uint64(len(chars))] + } + return string(result) +} + +// Simple pseudo-random number generator (not cryptographically secure) +var prngState uint64 = 0xdeadbeef + +func pseudoRand() uint64 { + prngState ^= prngState << 13 + prngState ^= prngState >> 7 + prngState ^= prngState << 17 + return prngState +} + +// InitPRNG seeds the pseudo-random number generator +func InitPRNG(seed uint64) { + prngState = seed + if prngState == 0 { + prngState = 0xdeadbeef + } +} + +// Note: IsPrivateIP is defined in geoip.go with proper CIDR parsing + +// NormalizeHeaderName converts header name to canonical form +func NormalizeHeaderName(name string) string { + // Map compact forms to full names + compactToFull := map[string]string{ + "i": "Call-ID", + "m": "Contact", + "e": "Content-Encoding", + "l": "Content-Length", + "c": "Content-Type", + "f": "From", + "s": "Subject", + "k": "Supported", + "t": "To", + "v": "Via", + } + + lower := strings.ToLower(name) + if full, ok := compactToFull[lower]; ok { + return full + } + + // Title case for standard headers + return strings.Title(lower) +} diff --git a/sipmsg.go b/sipmsg.go new file mode 100644 index 0000000..c70ab70 --- /dev/null +++ b/sipmsg.go @@ -0,0 +1,416 @@ +package sipguardian + +import ( + "bytes" + "errors" + "fmt" + "strconv" + "strings" +) + +// SIPMessage represents a parsed SIP message (request or response) +type SIPMessage struct { + // Common fields + IsRequest bool + SIPVersion string + Headers []SIPHeader + Body []byte + + // Request fields + Method string + RequestURI string + + // Response fields + StatusCode int + ReasonPhrase string +} + +// SIPHeader represents a single SIP header +type SIPHeader struct { + Name string + Value string +} + +// Common errors +var ( + ErrInvalidSIPMessage = errors.New("invalid SIP message format") + ErrEmptyMessage = errors.New("empty message") + ErrNoStartLine = errors.New("missing start line") +) + +// ParseSIPMessage parses raw bytes into a SIPMessage structure +func ParseSIPMessage(data []byte) (*SIPMessage, error) { + if len(data) == 0 { + return nil, ErrEmptyMessage + } + + msg := &SIPMessage{ + Headers: make([]SIPHeader, 0), + } + + // Split into headers and body at double CRLF + parts := bytes.SplitN(data, []byte("\r\n\r\n"), 2) + headerSection := parts[0] + if len(parts) > 1 { + msg.Body = parts[1] + } + + // Split header section into lines + lines := bytes.Split(headerSection, []byte("\r\n")) + if len(lines) == 0 { + return nil, ErrNoStartLine + } + + // Parse start line (first line) + if err := msg.parseStartLine(string(lines[0])); err != nil { + return nil, err + } + + // Parse headers (remaining lines) + for i := 1; i < len(lines); i++ { + line := string(lines[i]) + if line == "" { + continue + } + + // Handle header continuation (line starting with whitespace) + if len(line) > 0 && (line[0] == ' ' || line[0] == '\t') { + // Append to previous header + if len(msg.Headers) > 0 { + msg.Headers[len(msg.Headers)-1].Value += " " + strings.TrimSpace(line) + } + continue + } + + // Parse header name: value + colonIdx := strings.Index(line, ":") + if colonIdx > 0 { + name := strings.TrimSpace(line[:colonIdx]) + value := strings.TrimSpace(line[colonIdx+1:]) + msg.Headers = append(msg.Headers, SIPHeader{ + Name: name, + Value: value, + }) + } + } + + return msg, nil +} + +// parseStartLine parses the first line of a SIP message +func (m *SIPMessage) parseStartLine(line string) error { + parts := strings.Fields(line) + if len(parts) < 2 { + return ErrInvalidSIPMessage + } + + // Check if it's a response (starts with SIP/2.0) + if strings.HasPrefix(parts[0], "SIP/") { + m.IsRequest = false + m.SIPVersion = parts[0] + + // Parse status code + if len(parts) < 2 { + return ErrInvalidSIPMessage + } + code, err := strconv.Atoi(parts[1]) + if err != nil { + return fmt.Errorf("invalid status code: %v", err) + } + m.StatusCode = code + + // Reason phrase is the rest + if len(parts) >= 3 { + m.ReasonPhrase = strings.Join(parts[2:], " ") + } + } else { + // It's a request + m.IsRequest = true + m.Method = parts[0] + + if len(parts) < 3 { + return ErrInvalidSIPMessage + } + m.RequestURI = parts[1] + m.SIPVersion = parts[2] + } + + return nil +} + +// Serialize converts the SIPMessage back to wire format +func (m *SIPMessage) Serialize() []byte { + var buf bytes.Buffer + + // Write start line + if m.IsRequest { + buf.WriteString(m.Method) + buf.WriteByte(' ') + buf.WriteString(m.RequestURI) + buf.WriteByte(' ') + buf.WriteString(m.SIPVersion) + } else { + buf.WriteString(m.SIPVersion) + buf.WriteByte(' ') + buf.WriteString(strconv.Itoa(m.StatusCode)) + buf.WriteByte(' ') + buf.WriteString(m.ReasonPhrase) + } + buf.WriteString("\r\n") + + // Write headers + for _, h := range m.Headers { + buf.WriteString(h.Name) + buf.WriteString(": ") + buf.WriteString(h.Value) + buf.WriteString("\r\n") + } + + // End of headers + buf.WriteString("\r\n") + + // Write body if present + if len(m.Body) > 0 { + buf.Write(m.Body) + } + + return buf.Bytes() +} + +// GetHeader returns the first header with the given name (case-insensitive) +func (m *SIPMessage) GetHeader(name string) *SIPHeader { + lowerName := strings.ToLower(name) + // Also check compact form + compactName := getCompactForm(lowerName) + + for i := range m.Headers { + headerLower := strings.ToLower(m.Headers[i].Name) + if headerLower == lowerName || headerLower == compactName { + return &m.Headers[i] + } + } + return nil +} + +// GetHeaders returns all headers with the given name (case-insensitive) +func (m *SIPMessage) GetHeaders(name string) []*SIPHeader { + lowerName := strings.ToLower(name) + compactName := getCompactForm(lowerName) + + var result []*SIPHeader + for i := range m.Headers { + headerLower := strings.ToLower(m.Headers[i].Name) + if headerLower == lowerName || headerLower == compactName { + result = append(result, &m.Headers[i]) + } + } + return result +} + +// SetHeader sets a header value, replacing existing or adding new +func (m *SIPMessage) SetHeader(name, value string) { + lowerName := strings.ToLower(name) + compactName := getCompactForm(lowerName) + + for i := range m.Headers { + headerLower := strings.ToLower(m.Headers[i].Name) + if headerLower == lowerName || headerLower == compactName { + m.Headers[i].Value = value + return + } + } + // Not found, add new + m.Headers = append(m.Headers, SIPHeader{Name: name, Value: value}) +} + +// PrependHeader adds a header at the beginning of the header list +func (m *SIPMessage) PrependHeader(name, value string) { + newHeader := SIPHeader{Name: name, Value: value} + m.Headers = append([]SIPHeader{newHeader}, m.Headers...) +} + +// AppendHeader adds a header at the end of the header list +func (m *SIPMessage) AppendHeader(name, value string) { + m.Headers = append(m.Headers, SIPHeader{Name: name, Value: value}) +} + +// RemoveHeader removes all headers with the given name (case-insensitive) +func (m *SIPMessage) RemoveHeader(name string) { + lowerName := strings.ToLower(name) + compactName := getCompactForm(lowerName) + + newHeaders := make([]SIPHeader, 0, len(m.Headers)) + for _, h := range m.Headers { + headerLower := strings.ToLower(h.Name) + if headerLower != lowerName && headerLower != compactName { + newHeaders = append(newHeaders, h) + } + } + m.Headers = newHeaders +} + +// RemoveFirstHeader removes only the first header with the given name +func (m *SIPMessage) RemoveFirstHeader(name string) bool { + lowerName := strings.ToLower(name) + compactName := getCompactForm(lowerName) + + for i, h := range m.Headers { + headerLower := strings.ToLower(h.Name) + if headerLower == lowerName || headerLower == compactName { + m.Headers = append(m.Headers[:i], m.Headers[i+1:]...) + return true + } + } + return false +} + +// GetCallID returns the Call-ID header value +func (m *SIPMessage) GetCallID() string { + if h := m.GetHeader("Call-ID"); h != nil { + return h.Value + } + return "" +} + +// GetFromTag extracts the tag parameter from the From header +func (m *SIPMessage) GetFromTag() string { + if h := m.GetHeader("From"); h != nil { + return extractTagParam(h.Value) + } + return "" +} + +// GetToTag extracts the tag parameter from the To header +func (m *SIPMessage) GetToTag() string { + if h := m.GetHeader("To"); h != nil { + return extractTagParam(h.Value) + } + return "" +} + +// GetCSeq returns the CSeq number and method +func (m *SIPMessage) GetCSeq() (int, string) { + if h := m.GetHeader("CSeq"); h != nil { + parts := strings.Fields(h.Value) + if len(parts) >= 2 { + seq, _ := strconv.Atoi(parts[0]) + return seq, parts[1] + } + } + return 0, "" +} + +// GetViaBranch extracts the branch parameter from the top Via header +func (m *SIPMessage) GetViaBranch() string { + if h := m.GetHeader("Via"); h != nil { + return extractViaParam(h.Value, "branch") + } + return "" +} + +// Clone creates a deep copy of the message +func (m *SIPMessage) Clone() *SIPMessage { + clone := &SIPMessage{ + IsRequest: m.IsRequest, + SIPVersion: m.SIPVersion, + Method: m.Method, + RequestURI: m.RequestURI, + StatusCode: m.StatusCode, + ReasonPhrase: m.ReasonPhrase, + Headers: make([]SIPHeader, len(m.Headers)), + } + + copy(clone.Headers, m.Headers) + + if len(m.Body) > 0 { + clone.Body = make([]byte, len(m.Body)) + copy(clone.Body, m.Body) + } + + return clone +} + +// getCompactForm returns the compact form of a header name +func getCompactForm(name string) string { + compactForms := map[string]string{ + "call-id": "i", + "contact": "m", + "content-encoding": "e", + "content-length": "l", + "content-type": "c", + "from": "f", + "subject": "s", + "supported": "k", + "to": "t", + "via": "v", + } + if compact, ok := compactForms[name]; ok { + return compact + } + return "" +} + +// extractTagParam extracts the tag parameter from a From/To header value +func extractTagParam(headerValue string) string { + // Look for ;tag= parameter + idx := strings.Index(strings.ToLower(headerValue), ";tag=") + if idx < 0 { + return "" + } + + tagStart := idx + 5 // len(";tag=") + rest := headerValue[tagStart:] + + // Find end of tag value (semicolon or end of string) + endIdx := strings.IndexAny(rest, ";,>") + if endIdx < 0 { + return rest + } + return rest[:endIdx] +} + +// extractViaParam extracts a parameter from a Via header value +func extractViaParam(headerValue, param string) string { + lower := strings.ToLower(headerValue) + search := ";" + param + "=" + + idx := strings.Index(lower, search) + if idx < 0 { + return "" + } + + valueStart := idx + len(search) + rest := headerValue[valueStart:] + + // Find end of value + endIdx := strings.IndexAny(rest, ";,") + if endIdx < 0 { + return rest + } + return rest[:endIdx] +} + +// IsDialogCreating returns true if this is a dialog-creating request +func (m *SIPMessage) IsDialogCreating() bool { + if !m.IsRequest { + return false + } + switch m.Method { + case "INVITE", "SUBSCRIBE", "REFER", "NOTIFY": + return true + } + return false +} + +// IsDialogTerminating returns true if this ends a dialog +func (m *SIPMessage) IsDialogTerminating() bool { + if !m.IsRequest { + return false + } + return m.Method == "BYE" +} + +// UpdateContentLength updates the Content-Length header to match body size +func (m *SIPMessage) UpdateContentLength() { + m.SetHeader("Content-Length", strconv.Itoa(len(m.Body))) +} diff --git a/topology.go b/topology.go new file mode 100644 index 0000000..48c4627 --- /dev/null +++ b/topology.go @@ -0,0 +1,539 @@ +package sipguardian + +import ( + "crypto/rand" + "encoding/hex" + "fmt" + "net" + "strconv" + "strings" + "time" + + "github.com/caddyserver/caddy/v2" + "github.com/caddyserver/caddy/v2/caddyconfig/caddyfile" + "github.com/mholt/caddy-l4/layer4" + "go.uber.org/zap" +) + +func init() { + caddy.RegisterModule(TopologyHider{}) +} + +// TopologyHider is a Layer 4 handler that hides internal SIP topology +type TopologyHider struct { + // Enabled toggles topology hiding + Enabled bool `json:"enabled,omitempty"` + + // ProxyHost is the public IP address to use in rewritten headers + ProxyHost string `json:"proxy_host,omitempty"` + + // ProxyPort is the public port to use in rewritten headers + ProxyPort int `json:"proxy_port,omitempty"` + + // RewriteVia adds/modifies Via headers + RewriteVia bool `json:"rewrite_via,omitempty"` + + // RewriteContact rewrites Contact headers to hide internal addresses + RewriteContact bool `json:"rewrite_contact,omitempty"` + + // StripHeaders lists headers to remove (e.g., Server, P-Preferred-Identity) + StripHeaders []string `json:"strip_headers,omitempty"` + + // AnonymizeCallID replaces Call-ID with proxy-generated value + AnonymizeCallID bool `json:"anonymize_call_id,omitempty"` + + // HidePrivateIPs automatically detects and hides RFC 1918 addresses + HidePrivateIPs bool `json:"hide_private_ips,omitempty"` + + // TransactionTimeout for cleanup of pending transactions + TransactionTimeout caddy.Duration `json:"transaction_timeout,omitempty"` + + // Runtime state + logger *zap.Logger + transactions *TransactionManager + dialogs *DialogManager +} + +func (TopologyHider) CaddyModule() caddy.ModuleInfo { + return caddy.ModuleInfo{ + ID: "layer4.handlers.sip_topology_hider", + New: func() caddy.Module { return new(TopologyHider) }, + } +} + +func (h *TopologyHider) Provision(ctx caddy.Context) error { + h.logger = ctx.Logger() + + // Set defaults + if h.ProxyPort == 0 { + h.ProxyPort = 5060 + } + if h.TransactionTimeout == 0 { + h.TransactionTimeout = caddy.Duration(32 * time.Second) + } + + // Initialize transaction manager for response correlation + h.transactions = NewTransactionManager(h.logger, time.Duration(h.TransactionTimeout)) + + // Initialize dialog manager for stateful tracking + h.dialogs = NewDialogManager(h.logger, 10*time.Minute) + + // Start cleanup goroutine + go h.cleanupLoop(ctx) + + h.logger.Info("SIP Topology Hider initialized", + zap.Bool("enabled", h.Enabled), + zap.String("proxy_host", h.ProxyHost), + zap.Int("proxy_port", h.ProxyPort), + zap.Bool("rewrite_via", h.RewriteVia), + zap.Bool("rewrite_contact", h.RewriteContact), + zap.Strings("strip_headers", h.StripHeaders), + zap.Bool("anonymize_call_id", h.AnonymizeCallID), + zap.Bool("hide_private_ips", h.HidePrivateIPs), + ) + + return nil +} + +// Handle processes the SIP message and applies topology hiding +func (h *TopologyHider) Handle(cx *layer4.Connection, next layer4.Handler) error { + if !h.Enabled { + return next.Handle(cx) + } + + // Get client address + remoteAddr := cx.RemoteAddr().String() + clientHost, clientPortStr, err := net.SplitHostPort(remoteAddr) + if err != nil { + clientHost = remoteAddr + } + clientPort, _ := strconv.Atoi(clientPortStr) + + // Read the SIP message + buf := make([]byte, 8192) + n, err := cx.Read(buf) + if err != nil || n == 0 { + return next.Handle(cx) + } + buf = buf[:n] + + // Parse the SIP message + msg, err := ParseSIPMessage(buf) + if err != nil { + h.logger.Debug("Failed to parse SIP message for topology hiding", + zap.Error(err), + zap.String("client", clientHost), + ) + return next.Handle(cx) + } + + // Apply topology hiding based on message type + if msg.IsRequest { + h.handleRequest(msg, clientHost, clientPort) + } else { + h.handleResponse(msg) + } + + // Serialize modified message + modifiedData := msg.Serialize() + + // Write modified data back to connection buffer + // Note: This requires caddy-l4 support for modifying the connection data + // For now, we'll pass through the next handler + // In a full implementation, we'd use a wrapped connection + + h.logger.Debug("Topology hiding applied", + zap.Bool("is_request", msg.IsRequest), + zap.String("method", msg.Method), + zap.Int("status_code", msg.StatusCode), + zap.Int("original_size", n), + zap.Int("modified_size", len(modifiedData)), + ) + + // Continue to next handler + // TODO: Replace connection data with modified message + return next.Handle(cx) +} + +// handleRequest applies topology hiding to outgoing requests +func (h *TopologyHider) handleRequest(msg *SIPMessage, clientHost string, clientPort int) { + callID := msg.GetCallID() + fromTag := msg.GetFromTag() + + // Store original values for response correlation + originalVia := "" + if via := msg.GetHeader("Via"); via != nil { + originalVia = via.Value + } + + // Create transaction state + branch := GenerateBranch() + h.transactions.CreateTransaction(branch, msg.Method, clientHost, clientPort, originalVia) + + // For dialog-creating requests, create dialog state + if msg.IsDialogCreating() { + state := h.dialogs.CreateDialog(msg, clientHost, clientPort) + if state != nil && originalVia != "" { + h.dialogs.StoreOriginals(callID, fromTag, originalVia, "", "") + } + } + + // Add our Via header (prepend to top) + if h.RewriteVia { + newVia := h.buildViaHeader(branch) + msg.PrependHeader("Via", newVia) + + h.logger.Debug("Added Via header", + zap.String("via", newVia), + zap.String("branch", branch), + ) + } + + // Rewrite Contact header + if h.RewriteContact { + if contact := msg.GetHeader("Contact"); contact != nil { + originalContact := contact.Value + newContact := h.rewriteContactHeader(contact.Value) + contact.Value = newContact + + // Store original for response rewriting + if msg.IsDialogCreating() { + h.dialogs.StoreOriginals(callID, fromTag, "", originalContact, "") + } + + h.logger.Debug("Rewrote Contact header", + zap.String("original", originalContact), + zap.String("new", newContact), + ) + } + } + + // Anonymize Call-ID + if h.AnonymizeCallID { + if callIDHeader := msg.GetHeader("Call-ID"); callIDHeader != nil { + originalCallID := callIDHeader.Value + newCallID := h.generateAnonymousCallID() + callIDHeader.Value = newCallID + + // Store mapping for response correlation + h.dialogs.StoreOriginals(callID, fromTag, "", "", originalCallID) + + h.logger.Debug("Anonymized Call-ID", + zap.String("original", originalCallID), + zap.String("new", newCallID), + ) + } + } + + // Strip sensitive headers + h.stripSensitiveHeaders(msg) + + // Hide private IPs in other headers + if h.HidePrivateIPs { + h.hidePrivateIPsInHeaders(msg) + } +} + +// handleResponse applies topology hiding to responses +func (h *TopologyHider) handleResponse(msg *SIPMessage) { + // Find the transaction by branch + branch := msg.GetViaBranch() + transaction := h.transactions.GetTransaction(branch) + + if transaction == nil { + h.logger.Debug("No transaction found for response", + zap.String("branch", branch), + zap.Int("status_code", msg.StatusCode), + ) + return + } + + // Remove our Via header (should be top Via) + if h.RewriteVia { + // The top Via should be ours - remove it + topVia := msg.GetHeader("Via") + if topVia != nil { + viaHeader, _ := ParseViaHeader(topVia.Value) + if viaHeader != nil && viaHeader.Branch == branch { + msg.RemoveFirstHeader("Via") + h.logger.Debug("Removed proxy Via header", + zap.String("branch", branch), + ) + } + } + } + + // Update dialog state based on response + callID := msg.GetCallID() + fromTag := msg.GetFromTag() + toTag := msg.GetToTag() + + // Update dialog with To tag + if toTag != "" { + h.dialogs.UpdateDialog(callID, fromTag, toTag) + } + + // Mark dialog as confirmed for 2xx responses to INVITE + _, method := msg.GetCSeq() + if msg.StatusCode >= 200 && msg.StatusCode < 300 && method == "INVITE" { + h.dialogs.ConfirmDialog(callID, fromTag) + } + + // Strip sensitive headers from response too + h.stripSensitiveHeaders(msg) + + // Clean up transaction for final responses + if msg.StatusCode >= 200 { + h.transactions.RemoveTransaction(branch) + } +} + +// buildViaHeader creates a new Via header with proxy address +func (h *TopologyHider) buildViaHeader(branch string) string { + via := &ViaHeader{ + Protocol: "SIP/2.0", + Transport: "UDP", + Host: h.ProxyHost, + Port: h.ProxyPort, + Branch: branch, + } + return via.Serialize() +} + +// rewriteContactHeader replaces internal addresses with proxy address +func (h *TopologyHider) rewriteContactHeader(value string) string { + contact, err := ParseContactHeader(value) + if err != nil || contact.URI == "*" { + return value + } + + // Check if the Contact URI contains a private IP + host := ExtractHostFromURI(contact.URI) + if h.HidePrivateIPs && IsPrivateIP(host) { + // Rewrite URI to use proxy address + contact.URI = RewriteURIHost(contact.URI, h.ProxyHost, h.ProxyPort) + } else if h.RewriteContact { + // Always rewrite if RewriteContact is enabled + contact.URI = RewriteURIHost(contact.URI, h.ProxyHost, h.ProxyPort) + } + + return contact.Serialize() +} + +// stripSensitiveHeaders removes headers that leak internal information +func (h *TopologyHider) stripSensitiveHeaders(msg *SIPMessage) { + // Strip configured headers + for _, header := range h.StripHeaders { + msg.RemoveHeader(header) + } + + // Always strip these if HidePrivateIPs is enabled + if h.HidePrivateIPs { + for _, header := range SensitiveHeaders { + msg.RemoveHeader(header) + } + } +} + +// hidePrivateIPsInHeaders scans all headers for private IPs +func (h *TopologyHider) hidePrivateIPsInHeaders(msg *SIPMessage) { + // Headers that commonly contain IP addresses + ipHeaders := []string{ + "Record-Route", + "Route", + "P-Visited-Network-ID", + } + + for i := range msg.Headers { + headerLower := strings.ToLower(msg.Headers[i].Name) + + for _, ipHeader := range ipHeaders { + if headerLower == strings.ToLower(ipHeader) { + // Check for private IPs in the value + if containsPrivateIP(msg.Headers[i].Value) { + // Rewrite or remove the header + msg.Headers[i].Value = rewritePrivateIPs(msg.Headers[i].Value, h.ProxyHost) + } + } + } + } +} + +// generateAnonymousCallID creates a new random Call-ID +func (h *TopologyHider) generateAnonymousCallID() string { + bytes := make([]byte, 16) + rand.Read(bytes) + return hex.EncodeToString(bytes) + "@" + h.ProxyHost +} + +// containsPrivateIP checks if a string contains any private IP addresses +func containsPrivateIP(value string) bool { + // Simple pattern matching for common private IP formats + privatePatterns := []string{ + "10.", "192.168.", "172.16.", "172.17.", "172.18.", "172.19.", + "172.20.", "172.21.", "172.22.", "172.23.", "172.24.", "172.25.", + "172.26.", "172.27.", "172.28.", "172.29.", "172.30.", "172.31.", + } + + for _, pattern := range privatePatterns { + if strings.Contains(value, pattern) { + return true + } + } + return false +} + +// rewritePrivateIPs replaces private IP addresses in a header value +func rewritePrivateIPs(value, replacement string) string { + // This is a simplified implementation + // A full implementation would use regex to properly replace IPs in URIs + + // For Record-Route and Route, we may want to preserve the structure + // but replace the host portion + + return value // TODO: Implement proper IP replacement +} + +// cleanupLoop periodically cleans up expired state +func (h *TopologyHider) cleanupLoop(ctx caddy.Context) { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + transRemoved := h.transactions.Cleanup() + dialogRemoved := h.dialogs.Cleanup() + + if transRemoved > 0 || dialogRemoved > 0 { + h.logger.Debug("Topology hider cleanup", + zap.Int("transactions_removed", transRemoved), + zap.Int("dialogs_removed", dialogRemoved), + ) + } + } + } +} + +// GetStats returns statistics about the topology hider +func (h *TopologyHider) GetStats() map[string]interface{} { + stats := map[string]interface{}{ + "enabled": h.Enabled, + "proxy_host": h.ProxyHost, + "proxy_port": h.ProxyPort, + "rewrite_via": h.RewriteVia, + "rewrite_contact": h.RewriteContact, + "anonymize_callid": h.AnonymizeCallID, + "hide_private_ips": h.HidePrivateIPs, + } + + if h.transactions != nil { + for k, v := range h.transactions.GetStats() { + stats["transactions_"+k] = v + } + } + + if h.dialogs != nil { + for k, v := range h.dialogs.GetStats() { + stats["dialogs_"+k] = v + } + } + + return stats +} + +// UnmarshalCaddyfile implements caddyfile.Unmarshaler for TopologyHider. +// Usage in Caddyfile: +// +// sip_topology_hider { +// enabled true +// proxy_host 203.0.113.1 +// proxy_port 5060 +// rewrite_via +// rewrite_contact +// strip_headers P-Preferred-Identity P-Asserted-Identity Server +// anonymize_call_id +// hide_private_ips +// transaction_timeout 32s +// } +func (h *TopologyHider) UnmarshalCaddyfile(d *caddyfile.Dispenser) error { + // Move past "sip_topology_hider" token + d.Next() + + for nesting := d.Nesting(); d.NextBlock(nesting); { + switch d.Val() { + case "enabled": + if !d.NextArg() { + // No argument means enabled + h.Enabled = true + } else { + val := d.Val() + h.Enabled = val == "true" || val == "yes" || val == "on" + } + + case "proxy_host": + if !d.NextArg() { + return d.ArgErr() + } + h.ProxyHost = d.Val() + + case "proxy_port": + if !d.NextArg() { + return d.ArgErr() + } + port, err := strconv.Atoi(d.Val()) + if err != nil { + return d.Errf("invalid proxy_port: %v", err) + } + h.ProxyPort = port + + case "rewrite_via": + h.RewriteVia = true + + case "rewrite_contact": + h.RewriteContact = true + + case "strip_headers": + h.StripHeaders = d.RemainingArgs() + if len(h.StripHeaders) == 0 { + return d.ArgErr() + } + + case "anonymize_call_id": + h.AnonymizeCallID = true + + case "hide_private_ips": + h.HidePrivateIPs = true + + case "transaction_timeout": + if !d.NextArg() { + return d.ArgErr() + } + dur, err := caddy.ParseDuration(d.Val()) + if err != nil { + return d.Errf("invalid transaction_timeout: %v", err) + } + h.TransactionTimeout = caddy.Duration(dur) + + default: + return d.Errf("unknown sip_topology_hider directive: %s", d.Val()) + } + } + + // Validate required fields + if h.Enabled && h.ProxyHost == "" { + return fmt.Errorf("proxy_host is required when topology hiding is enabled") + } + + return nil +} + +// Interface guards +var ( + _ layer4.NextHandler = (*TopologyHider)(nil) + _ caddy.Module = (*TopologyHider)(nil) + _ caddy.Provisioner = (*TopologyHider)(nil) + _ caddyfile.Unmarshaler = (*TopologyHider)(nil) +) diff --git a/topology_test.go b/topology_test.go new file mode 100644 index 0000000..5b93629 --- /dev/null +++ b/topology_test.go @@ -0,0 +1,833 @@ +package sipguardian + +import ( + "strings" + "testing" + "time" + + "go.uber.org/zap" +) + +// ============================================================================= +// SIP Message Parsing Tests (sipmsg.go) +// ============================================================================= + +func TestParseSIPRequest(t *testing.T) { + raw := "INVITE sip:bob@example.com SIP/2.0\r\n" + + "Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK776asdhds\r\n" + + "From: \"Alice\" ;tag=1928301774\r\n" + + "To: \r\n" + + "Call-ID: a84b4c76e66710@pc33.example.com\r\n" + + "CSeq: 314159 INVITE\r\n" + + "Contact: \r\n" + + "Content-Type: application/sdp\r\n" + + "Content-Length: 4\r\n" + + "\r\n" + + "test" + + msg, err := ParseSIPMessage([]byte(raw)) + if err != nil { + t.Fatalf("Failed to parse SIP request: %v", err) + } + + if !msg.IsRequest { + t.Error("Expected IsRequest to be true") + } + if msg.Method != "INVITE" { + t.Errorf("Expected method INVITE, got %s", msg.Method) + } + if msg.RequestURI != "sip:bob@example.com" { + t.Errorf("Expected RequestURI sip:bob@example.com, got %s", msg.RequestURI) + } + if msg.SIPVersion != "SIP/2.0" { + t.Errorf("Expected SIP/2.0, got %s", msg.SIPVersion) + } + if string(msg.Body) != "test" { + t.Errorf("Expected body 'test', got '%s'", string(msg.Body)) + } +} + +func TestParseSIPResponse(t *testing.T) { + raw := "SIP/2.0 200 OK\r\n" + + "Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK776asdhds;received=10.0.0.1\r\n" + + "From: \"Alice\" ;tag=1928301774\r\n" + + "To: ;tag=a6c85cf\r\n" + + "Call-ID: a84b4c76e66710@pc33.example.com\r\n" + + "CSeq: 314159 INVITE\r\n" + + "Contact: \r\n" + + "Content-Length: 0\r\n" + + "\r\n" + + msg, err := ParseSIPMessage([]byte(raw)) + if err != nil { + t.Fatalf("Failed to parse SIP response: %v", err) + } + + if msg.IsRequest { + t.Error("Expected IsRequest to be false for response") + } + if msg.StatusCode != 200 { + t.Errorf("Expected status code 200, got %d", msg.StatusCode) + } + if msg.ReasonPhrase != "OK" { + t.Errorf("Expected reason phrase 'OK', got '%s'", msg.ReasonPhrase) + } +} + +func TestSIPMessageGetHeader(t *testing.T) { + raw := "REGISTER sip:example.com SIP/2.0\r\n" + + "Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK776\r\n" + + "From: ;tag=123\r\n" + + "To: \r\n" + + "Call-ID: test123@host\r\n" + + "CSeq: 1 REGISTER\r\n" + + "Content-Length: 0\r\n" + + "\r\n" + + msg, _ := ParseSIPMessage([]byte(raw)) + + // Test case-insensitive lookup + via := msg.GetHeader("via") + if via == nil { + t.Error("Failed to get Via header (lowercase)") + } + + Via := msg.GetHeader("Via") + if Via == nil { + t.Error("Failed to get Via header (mixed case)") + } + + // Test GetCallID helper + callID := msg.GetCallID() + if callID != "test123@host" { + t.Errorf("Expected Call-ID 'test123@host', got '%s'", callID) + } + + // Test GetFromTag helper + fromTag := msg.GetFromTag() + if fromTag != "123" { + t.Errorf("Expected From tag '123', got '%s'", fromTag) + } +} + +func TestSIPMessageCompactHeaders(t *testing.T) { + // Test compact header forms (RFC 3261 Section 7.3.3) + raw := "INVITE sip:bob@example.com SIP/2.0\r\n" + + "v: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK776\r\n" + + "f: ;tag=123\r\n" + + "t: \r\n" + + "i: compact-call-id@host\r\n" + + "CSeq: 1 INVITE\r\n" + + "m: \r\n" + + "l: 0\r\n" + + "\r\n" + + msg, err := ParseSIPMessage([]byte(raw)) + if err != nil { + t.Fatalf("Failed to parse message with compact headers: %v", err) + } + + // Via (v) + if msg.GetHeader("Via") == nil { + t.Error("Failed to get Via via compact form 'v'") + } + + // Call-ID (i) + if msg.GetCallID() != "compact-call-id@host" { + t.Error("Failed to get Call-ID via compact form 'i'") + } + + // Contact (m) + if msg.GetHeader("Contact") == nil { + t.Error("Failed to get Contact via compact form 'm'") + } +} + +func TestSIPMessageSerialize(t *testing.T) { + raw := "INVITE sip:bob@example.com SIP/2.0\r\n" + + "Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK776\r\n" + + "From: ;tag=123\r\n" + + "To: \r\n" + + "Call-ID: test@host\r\n" + + "CSeq: 1 INVITE\r\n" + + "Content-Length: 0\r\n" + + "\r\n" + + msg, _ := ParseSIPMessage([]byte(raw)) + serialized := msg.Serialize() + + // Parse again and verify round-trip + msg2, err := ParseSIPMessage(serialized) + if err != nil { + t.Fatalf("Failed to parse serialized message: %v", err) + } + + if msg.Method != msg2.Method { + t.Errorf("Method mismatch after round-trip: %s vs %s", msg.Method, msg2.Method) + } + if msg.GetCallID() != msg2.GetCallID() { + t.Errorf("Call-ID mismatch after round-trip") + } +} + +func TestSIPMessageHeaderManipulation(t *testing.T) { + raw := "INVITE sip:bob@example.com SIP/2.0\r\n" + + "Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK776\r\n" + + "From: ;tag=123\r\n" + + "To: \r\n" + + "Call-ID: test@host\r\n" + + "CSeq: 1 INVITE\r\n" + + "Content-Length: 0\r\n" + + "\r\n" + + msg, _ := ParseSIPMessage([]byte(raw)) + + // Test PrependHeader (for Via insertion) + msg.PrependHeader("Via", "SIP/2.0/UDP proxy.example.com:5060;branch=z9hG4bKproxy") + + vias := msg.GetHeaders("Via") + if len(vias) != 2 { + t.Errorf("Expected 2 Via headers after prepend, got %d", len(vias)) + } + if !strings.Contains(vias[0].Value, "proxy.example.com") { + t.Error("Prepended Via should be first") + } + + // Test RemoveFirstHeader + msg.RemoveFirstHeader("Via") + vias = msg.GetHeaders("Via") + if len(vias) != 1 { + t.Errorf("Expected 1 Via header after removal, got %d", len(vias)) + } + + // Test SetHeader + msg.SetHeader("User-Agent", "Test/1.0") + ua := msg.GetHeader("User-Agent") + if ua == nil || ua.Value != "Test/1.0" { + t.Error("SetHeader failed") + } + + // Test RemoveHeader + msg.RemoveHeader("User-Agent") + if msg.GetHeader("User-Agent") != nil { + t.Error("RemoveHeader failed") + } +} + +func TestSIPMessageClone(t *testing.T) { + raw := "INVITE sip:bob@example.com SIP/2.0\r\n" + + "Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK776\r\n" + + "From: ;tag=123\r\n" + + "To: \r\n" + + "Call-ID: test@host\r\n" + + "CSeq: 1 INVITE\r\n" + + "Content-Length: 4\r\n" + + "\r\n" + + "test" + + msg, _ := ParseSIPMessage([]byte(raw)) + clone := msg.Clone() + + // Modify original + msg.SetHeader("Via", "modified") + msg.Body = []byte("modified") + + // Clone should be unchanged + if clone.GetHeader("Via").Value == "modified" { + t.Error("Clone was affected by original modification") + } + if string(clone.Body) != "test" { + t.Error("Clone body was affected by original modification") + } +} + +// ============================================================================= +// Header Parsing Tests (sipheaders.go) +// ============================================================================= + +func TestParseViaHeader(t *testing.T) { + tests := []struct { + name string + input string + expected ViaHeader + }{ + { + name: "basic UDP", + input: "SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK776asdhds", + expected: ViaHeader{ + Protocol: "SIP/2.0", + Transport: "UDP", + Host: "192.168.1.100", + Port: 5060, + Branch: "z9hG4bK776asdhds", + }, + }, + { + name: "with received and rport", + input: "SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK776;received=10.0.0.1;rport=12345", + expected: ViaHeader{ + Protocol: "SIP/2.0", + Transport: "UDP", + Host: "192.168.1.100", + Port: 5060, + Branch: "z9hG4bK776", + Received: "10.0.0.1", + RPort: 12345, + }, + }, + { + name: "TCP transport", + input: "SIP/2.0/TCP proxy.example.com:5060;branch=z9hG4bKabc", + expected: ViaHeader{ + Protocol: "SIP/2.0", + Transport: "TCP", + Host: "proxy.example.com", + Port: 5060, + Branch: "z9hG4bKabc", + }, + }, + { + name: "default port", + input: "SIP/2.0/UDP 192.168.1.100;branch=z9hG4bK123", + expected: ViaHeader{ + Protocol: "SIP/2.0", + Transport: "UDP", + Host: "192.168.1.100", + Port: 5060, // Default + Branch: "z9hG4bK123", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + via, err := ParseViaHeader(tt.input) + if err != nil { + t.Fatalf("ParseViaHeader failed: %v", err) + } + + if via.Protocol != tt.expected.Protocol { + t.Errorf("Protocol: got %s, want %s", via.Protocol, tt.expected.Protocol) + } + if via.Transport != tt.expected.Transport { + t.Errorf("Transport: got %s, want %s", via.Transport, tt.expected.Transport) + } + if via.Host != tt.expected.Host { + t.Errorf("Host: got %s, want %s", via.Host, tt.expected.Host) + } + if via.Port != tt.expected.Port { + t.Errorf("Port: got %d, want %d", via.Port, tt.expected.Port) + } + if via.Branch != tt.expected.Branch { + t.Errorf("Branch: got %s, want %s", via.Branch, tt.expected.Branch) + } + }) + } +} + +func TestViaHeaderSerialize(t *testing.T) { + via := &ViaHeader{ + Protocol: "SIP/2.0", + Transport: "UDP", + Host: "192.168.1.100", + Port: 5060, + Branch: "z9hG4bK776", + Received: "10.0.0.1", + RPort: 12345, + } + + serialized := via.Serialize() + + // Parse it back + via2, err := ParseViaHeader(serialized) + if err != nil { + t.Fatalf("Failed to parse serialized Via: %v", err) + } + + if via2.Host != via.Host || via2.Branch != via.Branch || via2.Received != via.Received { + t.Error("Via header round-trip failed") + } +} + +func TestParseContactHeader(t *testing.T) { + tests := []struct { + name string + input string + uri string + dname string + }{ + { + name: "simple URI", + input: "", + uri: "sip:alice@192.168.1.100:5060", + dname: "", + }, + { + name: "with display name", + input: "\"Alice\" ;expires=3600", + uri: "sip:alice@example.com", + dname: "Alice", + }, + { + name: "star contact (REGISTER)", + input: "*", + uri: "*", + dname: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + contact, err := ParseContactHeader(tt.input) + if err != nil { + t.Fatalf("ParseContactHeader failed: %v", err) + } + + if contact.URI != tt.uri { + t.Errorf("URI: got %s, want %s", contact.URI, tt.uri) + } + if contact.DisplayName != tt.dname { + t.Errorf("DisplayName: got %s, want %s", contact.DisplayName, tt.dname) + } + }) + } +} + +func TestParseFromToHeader(t *testing.T) { + tests := []struct { + name string + input string + uri string + tag string + dname string + }{ + { + name: "From with tag", + input: "\"Bob\" ;tag=a6c85cf", + uri: "sip:bob@example.com", + tag: "a6c85cf", + dname: "Bob", + }, + { + name: "To without tag", + input: "", + uri: "sip:alice@example.com", + tag: "", + dname: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + header, err := ParseFromToHeader(tt.input) + if err != nil { + t.Fatalf("ParseFromToHeader failed: %v", err) + } + + if header.URI != tt.uri { + t.Errorf("URI: got %s, want %s", header.URI, tt.uri) + } + if header.Tag != tt.tag { + t.Errorf("Tag: got %s, want %s", header.Tag, tt.tag) + } + if header.DisplayName != tt.dname { + t.Errorf("DisplayName: got %s, want %s", header.DisplayName, tt.dname) + } + }) + } +} + +func TestExtractHostFromURI(t *testing.T) { + tests := []struct { + uri string + expected string + }{ + {"sip:alice@example.com", "example.com"}, + {"sip:bob@192.168.1.100:5060", "192.168.1.100"}, + {"sips:secure@tls.example.com:5061;transport=tls", "tls.example.com"}, + {"sip:user@host;param=value", "host"}, + } + + for _, tt := range tests { + t.Run(tt.uri, func(t *testing.T) { + host := ExtractHostFromURI(tt.uri) + if host != tt.expected { + t.Errorf("ExtractHostFromURI(%s) = %s, want %s", tt.uri, host, tt.expected) + } + }) + } +} + +func TestRewriteURIHost(t *testing.T) { + tests := []struct { + uri string + newHost string + newPort int + expected string + }{ + { + uri: "sip:alice@192.168.1.100:5060", + newHost: "proxy.example.com", + newPort: 5060, + expected: "sip:alice@proxy.example.com", + }, + { + uri: "sip:bob@internal.lan:5060;transport=udp", + newHost: "external.example.com", + newPort: 5080, + expected: "sip:bob@external.example.com:5080;transport=udp", + }, + } + + for _, tt := range tests { + t.Run(tt.uri, func(t *testing.T) { + result := RewriteURIHost(tt.uri, tt.newHost, tt.newPort) + if result != tt.expected { + t.Errorf("RewriteURIHost(%s, %s, %d) = %s, want %s", + tt.uri, tt.newHost, tt.newPort, result, tt.expected) + } + }) + } +} + +func TestGenerateBranch(t *testing.T) { + branch := GenerateBranch() + + // RFC 3261: Branch must start with "z9hG4bK" + if !strings.HasPrefix(branch, "z9hG4bK") { + t.Errorf("Branch %s does not start with z9hG4bK", branch) + } + + // Should be unique + branch2 := GenerateBranch() + if branch == branch2 { + t.Error("GenerateBranch should generate unique values") + } +} + +// ============================================================================= +// Dialog State Tests (dialog_state.go) +// ============================================================================= + +func TestDialogManagerCreateAndGet(t *testing.T) { + logger := zap.NewNop() + dm := NewDialogManager(logger, 5*time.Minute) + + // Create a mock SIP message + raw := "INVITE sip:bob@example.com SIP/2.0\r\n" + + "Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK776test\r\n" + + "From: ;tag=fromtag123\r\n" + + "To: \r\n" + + "Call-ID: testcall@example.com\r\n" + + "CSeq: 1 INVITE\r\n" + + "Contact: \r\n" + + "Content-Length: 0\r\n" + + "\r\n" + + msg, _ := ParseSIPMessage([]byte(raw)) + + // Create dialog + state := dm.CreateDialog(msg, "192.168.1.100", 5060) + + if state == nil { + t.Fatal("CreateDialog returned nil") + } + if state.CallID != "testcall@example.com" { + t.Errorf("CallID mismatch: got %s", state.CallID) + } + if state.FromTag != "fromtag123" { + t.Errorf("FromTag mismatch: got %s", state.FromTag) + } + + // Get by Call-ID + retrieved := dm.GetDialogByCallID("testcall@example.com", "fromtag123") + if retrieved == nil { + t.Fatal("GetDialogByCallID returned nil") + } + if retrieved != state { + t.Error("Retrieved dialog doesn't match created dialog") + } + + // Get by branch + byBranch := dm.GetDialogByBranch("z9hG4bK776test") + if byBranch == nil { + t.Fatal("GetDialogByBranch returned nil") + } + if byBranch != state { + t.Error("Retrieved dialog by branch doesn't match") + } +} + +func TestDialogManagerUpdateAndConfirm(t *testing.T) { + logger := zap.NewNop() + dm := NewDialogManager(logger, 5*time.Minute) + + raw := "INVITE sip:bob@example.com SIP/2.0\r\n" + + "Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bKupdate\r\n" + + "From: ;tag=from456\r\n" + + "To: \r\n" + + "Call-ID: updatetest@example.com\r\n" + + "CSeq: 1 INVITE\r\n" + + "Content-Length: 0\r\n" + + "\r\n" + + msg, _ := ParseSIPMessage([]byte(raw)) + dm.CreateDialog(msg, "192.168.1.100", 5060) + + // Update with To-tag (from 200 OK) + dm.UpdateDialog("updatetest@example.com", "from456", "totag789") + + state := dm.GetDialogByCallID("updatetest@example.com", "from456") + if state.ToTag != "totag789" { + t.Errorf("ToTag not updated: got %s", state.ToTag) + } + + // Confirm dialog + dm.ConfirmDialog("updatetest@example.com", "from456") + if !state.IsConfirmed { + t.Error("Dialog not confirmed") + } + + // Terminate dialog + dm.TerminateDialog("updatetest@example.com", "from456") + if !state.IsTerminated { + t.Error("Dialog not terminated") + } +} + +func TestDialogManagerRemove(t *testing.T) { + logger := zap.NewNop() + dm := NewDialogManager(logger, 5*time.Minute) + + raw := "INVITE sip:bob@example.com SIP/2.0\r\n" + + "Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bKremove\r\n" + + "From: ;tag=removetag\r\n" + + "To: \r\n" + + "Call-ID: removetest@example.com\r\n" + + "CSeq: 1 INVITE\r\n" + + "Content-Length: 0\r\n" + + "\r\n" + + msg, _ := ParseSIPMessage([]byte(raw)) + dm.CreateDialog(msg, "192.168.1.100", 5060) + + // Remove dialog + dm.RemoveDialog("removetest@example.com", "removetag") + + // Should be gone + if dm.GetDialogByCallID("removetest@example.com", "removetag") != nil { + t.Error("Dialog not removed from call-id index") + } + if dm.GetDialogByBranch("z9hG4bKremove") != nil { + t.Error("Dialog not removed from branch index") + } +} + +func TestDialogManagerStats(t *testing.T) { + logger := zap.NewNop() + dm := NewDialogManager(logger, 5*time.Minute) + + // Create multiple dialogs in different states + for i := 0; i < 3; i++ { + raw := "INVITE sip:bob@example.com SIP/2.0\r\n" + + "Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bKstats" + string(rune('a'+i)) + "\r\n" + + "From: ;tag=statstag" + string(rune('a'+i)) + "\r\n" + + "To: \r\n" + + "Call-ID: statstest" + string(rune('a'+i)) + "@example.com\r\n" + + "CSeq: 1 INVITE\r\n" + + "Content-Length: 0\r\n" + + "\r\n" + + msg, _ := ParseSIPMessage([]byte(raw)) + dm.CreateDialog(msg, "192.168.1.100", 5060) + } + + // Confirm one + dm.ConfirmDialog("statstesta@example.com", "statstaga") + + // Terminate one + dm.TerminateDialog("statstestb@example.com", "statstagb") + + stats := dm.GetStats() + if stats["total_dialogs"].(int) != 3 { + t.Errorf("Expected 3 total dialogs, got %v", stats["total_dialogs"]) + } + if stats["confirmed_dialogs"].(int) != 1 { + t.Errorf("Expected 1 confirmed dialog, got %v", stats["confirmed_dialogs"]) + } + if stats["terminated_dialogs"].(int) != 1 { + t.Errorf("Expected 1 terminated dialog, got %v", stats["terminated_dialogs"]) + } +} + +func TestTransactionManager(t *testing.T) { + logger := zap.NewNop() + tm := NewTransactionManager(logger, 32*time.Second) + + // Create transaction + state := tm.CreateTransaction("z9hG4bKtx123", "INVITE", "192.168.1.100", 5060, "SIP/2.0/UDP original:5060") + + if state == nil { + t.Fatal("CreateTransaction returned nil") + } + + // Retrieve + retrieved := tm.GetTransaction("z9hG4bKtx123") + if retrieved == nil { + t.Fatal("GetTransaction returned nil") + } + if retrieved.Method != "INVITE" { + t.Errorf("Method mismatch: got %s", retrieved.Method) + } + + // Remove + tm.RemoveTransaction("z9hG4bKtx123") + if tm.GetTransaction("z9hG4bKtx123") != nil { + t.Error("Transaction not removed") + } +} + +// ============================================================================= +// Topology Hider Tests (topology.go) +// ============================================================================= + +func TestTopologyHiderConfig(t *testing.T) { + th := &TopologyHider{ + Enabled: true, + ProxyHost: "proxy.example.com", + ProxyPort: 5060, + RewriteVia: true, + RewriteContact: true, + AnonymizeCallID: false, + HidePrivateIPs: true, + StripHeaders: []string{"Server", "User-Agent"}, + } + + if !th.Enabled { + t.Error("Expected Enabled to be true") + } + if len(th.StripHeaders) != 2 { + t.Errorf("Expected 2 strip headers, got %d", len(th.StripHeaders)) + } +} + +func TestIsDialogCreating(t *testing.T) { + tests := []struct { + method string + expected bool + }{ + {"INVITE", true}, + {"SUBSCRIBE", true}, + {"REFER", true}, + {"NOTIFY", true}, + {"REGISTER", false}, + {"OPTIONS", false}, + {"BYE", false}, + {"CANCEL", false}, + } + + for _, tt := range tests { + t.Run(tt.method, func(t *testing.T) { + raw := tt.method + " sip:bob@example.com SIP/2.0\r\n" + + "Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK123\r\n" + + "From: ;tag=123\r\n" + + "To: \r\n" + + "Call-ID: test@host\r\n" + + "CSeq: 1 " + tt.method + "\r\n" + + "Content-Length: 0\r\n" + + "\r\n" + + msg, _ := ParseSIPMessage([]byte(raw)) + if msg.IsDialogCreating() != tt.expected { + t.Errorf("%s: IsDialogCreating() = %v, want %v", + tt.method, msg.IsDialogCreating(), tt.expected) + } + }) + } +} + +func TestIsDialogTerminating(t *testing.T) { + byeMsg := "BYE sip:bob@example.com SIP/2.0\r\n" + + "Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK123\r\n" + + "From: ;tag=123\r\n" + + "To: ;tag=456\r\n" + + "Call-ID: test@host\r\n" + + "CSeq: 2 BYE\r\n" + + "Content-Length: 0\r\n" + + "\r\n" + + msg, _ := ParseSIPMessage([]byte(byeMsg)) + if !msg.IsDialogTerminating() { + t.Error("BYE should be dialog terminating") + } + + inviteMsg := "INVITE sip:bob@example.com SIP/2.0\r\n" + + "Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK123\r\n" + + "From: ;tag=123\r\n" + + "To: \r\n" + + "Call-ID: test@host\r\n" + + "CSeq: 1 INVITE\r\n" + + "Content-Length: 0\r\n" + + "\r\n" + + msg2, _ := ParseSIPMessage([]byte(inviteMsg)) + if msg2.IsDialogTerminating() { + t.Error("INVITE should not be dialog terminating") + } +} + +func TestSensitiveHeadersList(t *testing.T) { + expected := []string{ + "P-Preferred-Identity", + "P-Asserted-Identity", + "Remote-Party-ID", + "P-Charging-Vector", + "P-Charging-Function-Addresses", + "Server", + "User-Agent", + "X-Asterisk-HangupCause", + "X-Asterisk-HangupCauseCode", + } + + if len(SensitiveHeaders) != len(expected) { + t.Errorf("Expected %d sensitive headers, got %d", len(expected), len(SensitiveHeaders)) + } + + for _, h := range expected { + found := false + for _, sh := range SensitiveHeaders { + if sh == h { + found = true + break + } + } + if !found { + t.Errorf("Missing sensitive header: %s", h) + } + } +} + +func TestUpdateContentLength(t *testing.T) { + raw := "INVITE sip:bob@example.com SIP/2.0\r\n" + + "Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK123\r\n" + + "From: ;tag=123\r\n" + + "To: \r\n" + + "Call-ID: test@host\r\n" + + "CSeq: 1 INVITE\r\n" + + "Content-Length: 0\r\n" + + "\r\n" + + msg, _ := ParseSIPMessage([]byte(raw)) + + // Set a body + msg.Body = []byte("v=0\r\no=- 12345 12345 IN IP4 192.168.1.100\r\ns=-\r\n") + msg.UpdateContentLength() + + cl := msg.GetHeader("Content-Length") + if cl == nil { + t.Fatal("Content-Length header missing") + } + if cl.Value != "48" { + t.Errorf("Content-Length should be 48, got %s", cl.Value) + } +}