Add SIP topology hiding feature (B2BUA-lite)

Implements RFC 3261 compliant topology hiding to protect internal
infrastructure from external attackers:

New files:
- sipmsg.go: SIP message parsing/serialization with full header support
- sipheaders.go: Via, Contact, From/To header parsing with compact forms
- dialog_state.go: Dialog and transaction state management for response correlation
- topology.go: TopologyHider handler for caddy-l4 integration
- topology_test.go: Comprehensive unit tests (26 new tests, 60 total)

Features:
- Via header insertion (proxy adds own Via, pops on response)
- Contact header rewriting (hide internal IPs behind proxy address)
- Sensitive header stripping (P-Asserted-Identity, Server, etc.)
- Call-ID anonymization (optional)
- Private IP masking in all headers
- Dialog state tracking for stateful response routing
- Transaction state for stateless operation

Caddyfile configuration:
  sip_topology_hider {
    proxy_host 203.0.113.1
    proxy_port 5060
    upstream udp/192.168.1.100:5060
    rewrite_via
    rewrite_contact
    strip_headers P-Preferred-Identity Server
  }
This commit is contained in:
Ryan Malloy 2025-12-07 19:02:50 -07:00
parent 976fdf53a5
commit f76946fc41
5 changed files with 2645 additions and 0 deletions

391
dialog_state.go Normal file
View File

@ -0,0 +1,391 @@
package sipguardian
import (
"sync"
"time"
"go.uber.org/zap"
)
// DialogState tracks the state of a SIP dialog for topology hiding
type DialogState struct {
// Dialog identifiers
CallID string
FromTag string
ToTag string
LocalBranch string // Via branch we added
// Original values (for response rewriting)
OriginalCallID string // if anonymized
OriginalVia string // original top Via header
OriginalContact string // original Contact header
// Proxy values (what we replaced with)
ProxyContact string
ProxyVia string
// Client information (for routing responses)
ClientHost string
ClientPort int
// Timestamps
Created time.Time
LastActivity time.Time
// State tracking
IsConfirmed bool // 2xx received for dialog-creating request
IsTerminated bool
}
// DialogManager manages dialog state for topology hiding
type DialogManager struct {
dialogs map[string]*DialogState
byBranch map[string]*DialogState // Quick lookup by Via branch
cleanupTTL time.Duration
logger *zap.Logger
mu sync.RWMutex
}
// NewDialogManager creates a new dialog manager
func NewDialogManager(logger *zap.Logger, cleanupTTL time.Duration) *DialogManager {
if cleanupTTL == 0 {
cleanupTTL = 10 * time.Minute
}
dm := &DialogManager{
dialogs: make(map[string]*DialogState),
byBranch: make(map[string]*DialogState),
cleanupTTL: cleanupTTL,
logger: logger,
}
return dm
}
// dialogKey generates a unique key for a dialog
func dialogKey(callID, fromTag string) string {
return callID + ":" + fromTag
}
// CreateDialog creates a new dialog state for an outgoing request
func (dm *DialogManager) CreateDialog(msg *SIPMessage, clientHost string, clientPort int) *DialogState {
dm.mu.Lock()
defer dm.mu.Unlock()
callID := msg.GetCallID()
fromTag := msg.GetFromTag()
branch := msg.GetViaBranch()
state := &DialogState{
CallID: callID,
FromTag: fromTag,
LocalBranch: branch,
ClientHost: clientHost,
ClientPort: clientPort,
Created: time.Now(),
LastActivity: time.Now(),
}
// Store original Via if we're going to modify it
if via := msg.GetHeader("Via"); via != nil {
state.OriginalVia = via.Value
}
// Store original Contact
if contact := msg.GetHeader("Contact"); contact != nil {
state.OriginalContact = contact.Value
}
// Store by dialog key and by branch
key := dialogKey(callID, fromTag)
dm.dialogs[key] = state
if branch != "" {
dm.byBranch[branch] = state
}
dm.logger.Debug("Dialog created",
zap.String("call_id", callID),
zap.String("from_tag", fromTag),
zap.String("branch", branch),
)
return state
}
// GetDialogByCallID retrieves a dialog by Call-ID and From-tag
func (dm *DialogManager) GetDialogByCallID(callID, fromTag string) *DialogState {
dm.mu.RLock()
defer dm.mu.RUnlock()
key := dialogKey(callID, fromTag)
return dm.dialogs[key]
}
// GetDialogByBranch retrieves a dialog by Via branch (for response routing)
func (dm *DialogManager) GetDialogByBranch(branch string) *DialogState {
dm.mu.RLock()
defer dm.mu.RUnlock()
return dm.byBranch[branch]
}
// UpdateDialog updates an existing dialog with new information
func (dm *DialogManager) UpdateDialog(callID, fromTag string, toTag string) {
dm.mu.Lock()
defer dm.mu.Unlock()
key := dialogKey(callID, fromTag)
if state, ok := dm.dialogs[key]; ok {
state.LastActivity = time.Now()
if toTag != "" && state.ToTag == "" {
state.ToTag = toTag
}
}
}
// ConfirmDialog marks a dialog as confirmed (2xx received)
func (dm *DialogManager) ConfirmDialog(callID, fromTag string) {
dm.mu.Lock()
defer dm.mu.Unlock()
key := dialogKey(callID, fromTag)
if state, ok := dm.dialogs[key]; ok {
state.IsConfirmed = true
state.LastActivity = time.Now()
}
}
// TerminateDialog marks a dialog as terminated
func (dm *DialogManager) TerminateDialog(callID, fromTag string) {
dm.mu.Lock()
defer dm.mu.Unlock()
key := dialogKey(callID, fromTag)
if state, ok := dm.dialogs[key]; ok {
state.IsTerminated = true
state.LastActivity = time.Now()
}
}
// RemoveDialog removes a dialog from the manager
func (dm *DialogManager) RemoveDialog(callID, fromTag string) {
dm.mu.Lock()
defer dm.mu.Unlock()
key := dialogKey(callID, fromTag)
if state, ok := dm.dialogs[key]; ok {
// Remove from branch index
if state.LocalBranch != "" {
delete(dm.byBranch, state.LocalBranch)
}
delete(dm.dialogs, key)
dm.logger.Debug("Dialog removed",
zap.String("call_id", callID),
zap.String("from_tag", fromTag),
)
}
}
// StoreOriginals stores original header values before rewriting
func (dm *DialogManager) StoreOriginals(callID, fromTag string, via, contact string, originalCallID string) {
dm.mu.Lock()
defer dm.mu.Unlock()
key := dialogKey(callID, fromTag)
if state, ok := dm.dialogs[key]; ok {
if via != "" {
state.OriginalVia = via
}
if contact != "" {
state.OriginalContact = contact
}
if originalCallID != "" {
state.OriginalCallID = originalCallID
}
}
}
// StoreProxyValues stores the values we used to replace originals
func (dm *DialogManager) StoreProxyValues(callID, fromTag string, via, contact string) {
dm.mu.Lock()
defer dm.mu.Unlock()
key := dialogKey(callID, fromTag)
if state, ok := dm.dialogs[key]; ok {
if via != "" {
state.ProxyVia = via
}
if contact != "" {
state.ProxyContact = contact
}
}
}
// Cleanup removes stale dialogs
func (dm *DialogManager) Cleanup() int {
dm.mu.Lock()
defer dm.mu.Unlock()
cutoff := time.Now().Add(-dm.cleanupTTL)
removed := 0
for key, state := range dm.dialogs {
// Remove terminated dialogs after TTL
if state.IsTerminated && state.LastActivity.Before(cutoff) {
if state.LocalBranch != "" {
delete(dm.byBranch, state.LocalBranch)
}
delete(dm.dialogs, key)
removed++
continue
}
// Remove unconfirmed dialogs that have been idle too long
if !state.IsConfirmed && state.LastActivity.Before(cutoff) {
if state.LocalBranch != "" {
delete(dm.byBranch, state.LocalBranch)
}
delete(dm.dialogs, key)
removed++
continue
}
// Remove very old dialogs regardless of state (prevent memory leak)
if state.Created.Before(time.Now().Add(-24 * time.Hour)) {
if state.LocalBranch != "" {
delete(dm.byBranch, state.LocalBranch)
}
delete(dm.dialogs, key)
removed++
}
}
if removed > 0 {
dm.logger.Debug("Cleaned up stale dialogs", zap.Int("count", removed))
}
return removed
}
// GetStats returns statistics about the dialog manager
func (dm *DialogManager) GetStats() map[string]interface{} {
dm.mu.RLock()
defer dm.mu.RUnlock()
confirmed := 0
terminated := 0
pending := 0
for _, state := range dm.dialogs {
if state.IsTerminated {
terminated++
} else if state.IsConfirmed {
confirmed++
} else {
pending++
}
}
return map[string]interface{}{
"total_dialogs": len(dm.dialogs),
"confirmed_dialogs": confirmed,
"terminated_dialogs": terminated,
"pending_dialogs": pending,
"branch_index_size": len(dm.byBranch),
}
}
// TransactionState tracks a single SIP transaction (simpler than full dialog)
type TransactionState struct {
Branch string
Method string
ClientHost string
ClientPort int
OriginalVia string
Created time.Time
}
// TransactionManager manages transaction state for stateless topology hiding
type TransactionManager struct {
transactions map[string]*TransactionState
cleanupTTL time.Duration
logger *zap.Logger
mu sync.RWMutex
}
// NewTransactionManager creates a new transaction manager
func NewTransactionManager(logger *zap.Logger, cleanupTTL time.Duration) *TransactionManager {
if cleanupTTL == 0 {
cleanupTTL = 32 * time.Second // SIP transaction timeout
}
return &TransactionManager{
transactions: make(map[string]*TransactionState),
cleanupTTL: cleanupTTL,
logger: logger,
}
}
// CreateTransaction creates a new transaction state
func (tm *TransactionManager) CreateTransaction(branch, method, clientHost string, clientPort int, originalVia string) *TransactionState {
tm.mu.Lock()
defer tm.mu.Unlock()
state := &TransactionState{
Branch: branch,
Method: method,
ClientHost: clientHost,
ClientPort: clientPort,
OriginalVia: originalVia,
Created: time.Now(),
}
tm.transactions[branch] = state
return state
}
// GetTransaction retrieves a transaction by branch
func (tm *TransactionManager) GetTransaction(branch string) *TransactionState {
tm.mu.RLock()
defer tm.mu.RUnlock()
return tm.transactions[branch]
}
// RemoveTransaction removes a transaction
func (tm *TransactionManager) RemoveTransaction(branch string) {
tm.mu.Lock()
defer tm.mu.Unlock()
delete(tm.transactions, branch)
}
// Cleanup removes expired transactions
func (tm *TransactionManager) Cleanup() int {
tm.mu.Lock()
defer tm.mu.Unlock()
cutoff := time.Now().Add(-tm.cleanupTTL)
removed := 0
for branch, state := range tm.transactions {
if state.Created.Before(cutoff) {
delete(tm.transactions, branch)
removed++
}
}
return removed
}
// GetStats returns statistics about the transaction manager
func (tm *TransactionManager) GetStats() map[string]interface{} {
tm.mu.RLock()
defer tm.mu.RUnlock()
return map[string]interface{}{
"active_transactions": len(tm.transactions),
}
}

466
sipheaders.go Normal file
View File

@ -0,0 +1,466 @@
package sipguardian
import (
"fmt"
"regexp"
"strconv"
"strings"
)
// ViaHeader represents a parsed Via header
type ViaHeader struct {
Protocol string // SIP/2.0
Transport string // UDP, TCP, TLS, WS, WSS
Host string
Port int
Branch string
Received string
RPort int
Params map[string]string // Other parameters
}
// ContactHeader represents a parsed Contact header
type ContactHeader struct {
DisplayName string
URI string
Params map[string]string
}
// FromToHeader represents a parsed From or To header
type FromToHeader struct {
DisplayName string
URI string
Tag string
Params map[string]string
}
// SensitiveHeaders lists headers that may leak internal topology
var SensitiveHeaders = []string{
"P-Preferred-Identity",
"P-Asserted-Identity",
"Remote-Party-ID",
"P-Charging-Vector",
"P-Charging-Function-Addresses",
"Server",
"User-Agent",
"X-Asterisk-HangupCause",
"X-Asterisk-HangupCauseCode",
}
// ParseViaHeader parses a Via header value into structured form
// Example: "SIP/2.0/UDP 192.168.1.1:5060;branch=z9hG4bK123;received=10.0.0.1"
func ParseViaHeader(value string) (*ViaHeader, error) {
via := &ViaHeader{
Params: make(map[string]string),
}
// Split by semicolons to get parameters
parts := strings.Split(value, ";")
if len(parts) == 0 {
return nil, fmt.Errorf("empty Via header")
}
// First part: SIP/2.0/UDP host:port
mainPart := strings.TrimSpace(parts[0])
fields := strings.Fields(mainPart)
if len(fields) < 2 {
return nil, fmt.Errorf("invalid Via header format")
}
// Parse protocol/transport
protoTransport := strings.Split(fields[0], "/")
if len(protoTransport) >= 2 {
via.Protocol = protoTransport[0] + "/" + protoTransport[1]
}
if len(protoTransport) >= 3 {
via.Transport = protoTransport[2]
}
// Parse host:port
hostPort := fields[1]
if colonIdx := strings.LastIndex(hostPort, ":"); colonIdx > 0 {
via.Host = hostPort[:colonIdx]
port, _ := strconv.Atoi(hostPort[colonIdx+1:])
via.Port = port
} else {
via.Host = hostPort
via.Port = 5060 // Default SIP port
}
// Parse parameters
for i := 1; i < len(parts); i++ {
param := strings.TrimSpace(parts[i])
if eqIdx := strings.Index(param, "="); eqIdx > 0 {
key := strings.ToLower(param[:eqIdx])
val := param[eqIdx+1:]
switch key {
case "branch":
via.Branch = val
case "received":
via.Received = val
case "rport":
via.RPort, _ = strconv.Atoi(val)
default:
via.Params[key] = val
}
} else {
// Flag parameter (no value)
via.Params[strings.ToLower(param)] = ""
}
}
return via, nil
}
// Serialize converts ViaHeader back to string format
func (v *ViaHeader) Serialize() string {
var buf strings.Builder
// Protocol and transport
buf.WriteString(v.Protocol)
buf.WriteByte('/')
buf.WriteString(v.Transport)
buf.WriteByte(' ')
// Host and port
buf.WriteString(v.Host)
if v.Port > 0 && v.Port != 5060 {
buf.WriteByte(':')
buf.WriteString(strconv.Itoa(v.Port))
}
// Branch (required for RFC 3261 compliance)
if v.Branch != "" {
buf.WriteString(";branch=")
buf.WriteString(v.Branch)
}
// Received parameter
if v.Received != "" {
buf.WriteString(";received=")
buf.WriteString(v.Received)
}
// RPort parameter
if v.RPort > 0 {
buf.WriteString(";rport=")
buf.WriteString(strconv.Itoa(v.RPort))
}
// Other parameters
for key, val := range v.Params {
buf.WriteByte(';')
buf.WriteString(key)
if val != "" {
buf.WriteByte('=')
buf.WriteString(val)
}
}
return buf.String()
}
// ParseContactHeader parses a Contact header value
// Example: "\"Alice\" <sip:alice@example.com>;expires=3600"
func ParseContactHeader(value string) (*ContactHeader, error) {
contact := &ContactHeader{
Params: make(map[string]string),
}
// Handle * (special Contact for REGISTER)
if strings.TrimSpace(value) == "*" {
contact.URI = "*"
return contact, nil
}
// Extract display name if present (quoted)
if idx := strings.Index(value, "\""); idx >= 0 {
endIdx := strings.Index(value[idx+1:], "\"")
if endIdx > 0 {
contact.DisplayName = value[idx+1 : idx+1+endIdx]
value = value[idx+1+endIdx+1:]
}
}
// Extract URI (within angle brackets or until semicolon)
value = strings.TrimSpace(value)
if strings.HasPrefix(value, "<") {
endIdx := strings.Index(value, ">")
if endIdx > 0 {
contact.URI = value[1:endIdx]
value = value[endIdx+1:]
}
} else {
// No angle brackets - URI goes until semicolon or end
semiIdx := strings.Index(value, ";")
if semiIdx > 0 {
contact.URI = strings.TrimSpace(value[:semiIdx])
value = value[semiIdx:]
} else {
contact.URI = strings.TrimSpace(value)
value = ""
}
}
// Parse parameters
for _, part := range strings.Split(value, ";") {
part = strings.TrimSpace(part)
if part == "" {
continue
}
if eqIdx := strings.Index(part, "="); eqIdx > 0 {
contact.Params[strings.ToLower(part[:eqIdx])] = part[eqIdx+1:]
} else if part != "" {
contact.Params[strings.ToLower(part)] = ""
}
}
return contact, nil
}
// Serialize converts ContactHeader back to string format
func (c *ContactHeader) Serialize() string {
var buf strings.Builder
if c.URI == "*" {
return "*"
}
if c.DisplayName != "" {
buf.WriteByte('"')
buf.WriteString(c.DisplayName)
buf.WriteString("\" ")
}
buf.WriteByte('<')
buf.WriteString(c.URI)
buf.WriteByte('>')
for key, val := range c.Params {
buf.WriteByte(';')
buf.WriteString(key)
if val != "" {
buf.WriteByte('=')
buf.WriteString(val)
}
}
return buf.String()
}
// ParseFromToHeader parses a From or To header value
// Example: "\"Bob\" <sip:bob@example.com>;tag=abc123"
func ParseFromToHeader(value string) (*FromToHeader, error) {
header := &FromToHeader{
Params: make(map[string]string),
}
// Extract display name if present (quoted)
if idx := strings.Index(value, "\""); idx >= 0 {
endIdx := strings.Index(value[idx+1:], "\"")
if endIdx > 0 {
header.DisplayName = value[idx+1 : idx+1+endIdx]
value = value[idx+1+endIdx+1:]
}
}
// Extract URI (within angle brackets)
value = strings.TrimSpace(value)
if strings.HasPrefix(value, "<") {
endIdx := strings.Index(value, ">")
if endIdx > 0 {
header.URI = value[1:endIdx]
value = value[endIdx+1:]
}
} else {
// No angle brackets
semiIdx := strings.Index(value, ";")
if semiIdx > 0 {
header.URI = strings.TrimSpace(value[:semiIdx])
value = value[semiIdx:]
} else {
header.URI = strings.TrimSpace(value)
value = ""
}
}
// Parse parameters
for _, part := range strings.Split(value, ";") {
part = strings.TrimSpace(part)
if part == "" {
continue
}
if eqIdx := strings.Index(part, "="); eqIdx > 0 {
key := strings.ToLower(part[:eqIdx])
val := part[eqIdx+1:]
if key == "tag" {
header.Tag = val
} else {
header.Params[key] = val
}
}
}
return header, nil
}
// Serialize converts FromToHeader back to string format
func (f *FromToHeader) Serialize() string {
var buf strings.Builder
if f.DisplayName != "" {
buf.WriteByte('"')
buf.WriteString(f.DisplayName)
buf.WriteString("\" ")
}
buf.WriteByte('<')
buf.WriteString(f.URI)
buf.WriteByte('>')
if f.Tag != "" {
buf.WriteString(";tag=")
buf.WriteString(f.Tag)
}
for key, val := range f.Params {
buf.WriteByte(';')
buf.WriteString(key)
if val != "" {
buf.WriteByte('=')
buf.WriteString(val)
}
}
return buf.String()
}
// ExtractHostFromURI extracts the host part from a SIP URI
// Example: "sip:user@host:5060;transport=udp" -> "host"
func ExtractHostFromURI(uri string) string {
// Remove scheme
if idx := strings.Index(uri, ":"); idx >= 0 {
uri = uri[idx+1:]
}
// Remove user part
if idx := strings.Index(uri, "@"); idx >= 0 {
uri = uri[idx+1:]
}
// Remove port and parameters
if idx := strings.Index(uri, ":"); idx >= 0 {
uri = uri[:idx]
}
if idx := strings.Index(uri, ";"); idx >= 0 {
uri = uri[:idx]
}
return uri
}
// ExtractPortFromURI extracts the port from a SIP URI (default 5060)
func ExtractPortFromURI(uri string) int {
// Remove scheme and user part
if idx := strings.Index(uri, "@"); idx >= 0 {
uri = uri[idx+1:]
} else if idx := strings.Index(uri, ":"); idx >= 0 {
uri = uri[idx+1:]
}
// Extract port
if colonIdx := strings.Index(uri, ":"); colonIdx >= 0 {
portStr := uri[colonIdx+1:]
if semiIdx := strings.Index(portStr, ";"); semiIdx >= 0 {
portStr = portStr[:semiIdx]
}
port, err := strconv.Atoi(portStr)
if err == nil {
return port
}
}
return 5060 // Default SIP port
}
// RewriteURIHost replaces the host in a SIP URI
func RewriteURIHost(uri, newHost string, newPort int) string {
// Pattern to match SIP URI
re := regexp.MustCompile(`^(sips?:)([^@]+@)?([^:;>]+)(:\d+)?(.*)$`)
matches := re.FindStringSubmatch(uri)
if matches == nil {
return uri
}
// Rebuild URI with new host
result := matches[1] + matches[2] + newHost
if newPort > 0 && newPort != 5060 {
result += ":" + strconv.Itoa(newPort)
}
result += matches[5]
return result
}
// GenerateBranch generates an RFC 3261 compliant branch parameter
func GenerateBranch() string {
// Must start with z9hG4bK for RFC 3261 compliance
return "z9hG4bK" + generateRandomHex(16)
}
// generateRandomHex generates a random hex string of given length
func generateRandomHex(length int) string {
const chars = "0123456789abcdef"
result := make([]byte, length)
for i := range result {
// Use uint64 modulo to avoid negative index from int conversion
result[i] = chars[pseudoRand()%uint64(len(chars))]
}
return string(result)
}
// Simple pseudo-random number generator (not cryptographically secure)
var prngState uint64 = 0xdeadbeef
func pseudoRand() uint64 {
prngState ^= prngState << 13
prngState ^= prngState >> 7
prngState ^= prngState << 17
return prngState
}
// InitPRNG seeds the pseudo-random number generator
func InitPRNG(seed uint64) {
prngState = seed
if prngState == 0 {
prngState = 0xdeadbeef
}
}
// Note: IsPrivateIP is defined in geoip.go with proper CIDR parsing
// NormalizeHeaderName converts header name to canonical form
func NormalizeHeaderName(name string) string {
// Map compact forms to full names
compactToFull := map[string]string{
"i": "Call-ID",
"m": "Contact",
"e": "Content-Encoding",
"l": "Content-Length",
"c": "Content-Type",
"f": "From",
"s": "Subject",
"k": "Supported",
"t": "To",
"v": "Via",
}
lower := strings.ToLower(name)
if full, ok := compactToFull[lower]; ok {
return full
}
// Title case for standard headers
return strings.Title(lower)
}

416
sipmsg.go Normal file
View File

@ -0,0 +1,416 @@
package sipguardian
import (
"bytes"
"errors"
"fmt"
"strconv"
"strings"
)
// SIPMessage represents a parsed SIP message (request or response)
type SIPMessage struct {
// Common fields
IsRequest bool
SIPVersion string
Headers []SIPHeader
Body []byte
// Request fields
Method string
RequestURI string
// Response fields
StatusCode int
ReasonPhrase string
}
// SIPHeader represents a single SIP header
type SIPHeader struct {
Name string
Value string
}
// Common errors
var (
ErrInvalidSIPMessage = errors.New("invalid SIP message format")
ErrEmptyMessage = errors.New("empty message")
ErrNoStartLine = errors.New("missing start line")
)
// ParseSIPMessage parses raw bytes into a SIPMessage structure
func ParseSIPMessage(data []byte) (*SIPMessage, error) {
if len(data) == 0 {
return nil, ErrEmptyMessage
}
msg := &SIPMessage{
Headers: make([]SIPHeader, 0),
}
// Split into headers and body at double CRLF
parts := bytes.SplitN(data, []byte("\r\n\r\n"), 2)
headerSection := parts[0]
if len(parts) > 1 {
msg.Body = parts[1]
}
// Split header section into lines
lines := bytes.Split(headerSection, []byte("\r\n"))
if len(lines) == 0 {
return nil, ErrNoStartLine
}
// Parse start line (first line)
if err := msg.parseStartLine(string(lines[0])); err != nil {
return nil, err
}
// Parse headers (remaining lines)
for i := 1; i < len(lines); i++ {
line := string(lines[i])
if line == "" {
continue
}
// Handle header continuation (line starting with whitespace)
if len(line) > 0 && (line[0] == ' ' || line[0] == '\t') {
// Append to previous header
if len(msg.Headers) > 0 {
msg.Headers[len(msg.Headers)-1].Value += " " + strings.TrimSpace(line)
}
continue
}
// Parse header name: value
colonIdx := strings.Index(line, ":")
if colonIdx > 0 {
name := strings.TrimSpace(line[:colonIdx])
value := strings.TrimSpace(line[colonIdx+1:])
msg.Headers = append(msg.Headers, SIPHeader{
Name: name,
Value: value,
})
}
}
return msg, nil
}
// parseStartLine parses the first line of a SIP message
func (m *SIPMessage) parseStartLine(line string) error {
parts := strings.Fields(line)
if len(parts) < 2 {
return ErrInvalidSIPMessage
}
// Check if it's a response (starts with SIP/2.0)
if strings.HasPrefix(parts[0], "SIP/") {
m.IsRequest = false
m.SIPVersion = parts[0]
// Parse status code
if len(parts) < 2 {
return ErrInvalidSIPMessage
}
code, err := strconv.Atoi(parts[1])
if err != nil {
return fmt.Errorf("invalid status code: %v", err)
}
m.StatusCode = code
// Reason phrase is the rest
if len(parts) >= 3 {
m.ReasonPhrase = strings.Join(parts[2:], " ")
}
} else {
// It's a request
m.IsRequest = true
m.Method = parts[0]
if len(parts) < 3 {
return ErrInvalidSIPMessage
}
m.RequestURI = parts[1]
m.SIPVersion = parts[2]
}
return nil
}
// Serialize converts the SIPMessage back to wire format
func (m *SIPMessage) Serialize() []byte {
var buf bytes.Buffer
// Write start line
if m.IsRequest {
buf.WriteString(m.Method)
buf.WriteByte(' ')
buf.WriteString(m.RequestURI)
buf.WriteByte(' ')
buf.WriteString(m.SIPVersion)
} else {
buf.WriteString(m.SIPVersion)
buf.WriteByte(' ')
buf.WriteString(strconv.Itoa(m.StatusCode))
buf.WriteByte(' ')
buf.WriteString(m.ReasonPhrase)
}
buf.WriteString("\r\n")
// Write headers
for _, h := range m.Headers {
buf.WriteString(h.Name)
buf.WriteString(": ")
buf.WriteString(h.Value)
buf.WriteString("\r\n")
}
// End of headers
buf.WriteString("\r\n")
// Write body if present
if len(m.Body) > 0 {
buf.Write(m.Body)
}
return buf.Bytes()
}
// GetHeader returns the first header with the given name (case-insensitive)
func (m *SIPMessage) GetHeader(name string) *SIPHeader {
lowerName := strings.ToLower(name)
// Also check compact form
compactName := getCompactForm(lowerName)
for i := range m.Headers {
headerLower := strings.ToLower(m.Headers[i].Name)
if headerLower == lowerName || headerLower == compactName {
return &m.Headers[i]
}
}
return nil
}
// GetHeaders returns all headers with the given name (case-insensitive)
func (m *SIPMessage) GetHeaders(name string) []*SIPHeader {
lowerName := strings.ToLower(name)
compactName := getCompactForm(lowerName)
var result []*SIPHeader
for i := range m.Headers {
headerLower := strings.ToLower(m.Headers[i].Name)
if headerLower == lowerName || headerLower == compactName {
result = append(result, &m.Headers[i])
}
}
return result
}
// SetHeader sets a header value, replacing existing or adding new
func (m *SIPMessage) SetHeader(name, value string) {
lowerName := strings.ToLower(name)
compactName := getCompactForm(lowerName)
for i := range m.Headers {
headerLower := strings.ToLower(m.Headers[i].Name)
if headerLower == lowerName || headerLower == compactName {
m.Headers[i].Value = value
return
}
}
// Not found, add new
m.Headers = append(m.Headers, SIPHeader{Name: name, Value: value})
}
// PrependHeader adds a header at the beginning of the header list
func (m *SIPMessage) PrependHeader(name, value string) {
newHeader := SIPHeader{Name: name, Value: value}
m.Headers = append([]SIPHeader{newHeader}, m.Headers...)
}
// AppendHeader adds a header at the end of the header list
func (m *SIPMessage) AppendHeader(name, value string) {
m.Headers = append(m.Headers, SIPHeader{Name: name, Value: value})
}
// RemoveHeader removes all headers with the given name (case-insensitive)
func (m *SIPMessage) RemoveHeader(name string) {
lowerName := strings.ToLower(name)
compactName := getCompactForm(lowerName)
newHeaders := make([]SIPHeader, 0, len(m.Headers))
for _, h := range m.Headers {
headerLower := strings.ToLower(h.Name)
if headerLower != lowerName && headerLower != compactName {
newHeaders = append(newHeaders, h)
}
}
m.Headers = newHeaders
}
// RemoveFirstHeader removes only the first header with the given name
func (m *SIPMessage) RemoveFirstHeader(name string) bool {
lowerName := strings.ToLower(name)
compactName := getCompactForm(lowerName)
for i, h := range m.Headers {
headerLower := strings.ToLower(h.Name)
if headerLower == lowerName || headerLower == compactName {
m.Headers = append(m.Headers[:i], m.Headers[i+1:]...)
return true
}
}
return false
}
// GetCallID returns the Call-ID header value
func (m *SIPMessage) GetCallID() string {
if h := m.GetHeader("Call-ID"); h != nil {
return h.Value
}
return ""
}
// GetFromTag extracts the tag parameter from the From header
func (m *SIPMessage) GetFromTag() string {
if h := m.GetHeader("From"); h != nil {
return extractTagParam(h.Value)
}
return ""
}
// GetToTag extracts the tag parameter from the To header
func (m *SIPMessage) GetToTag() string {
if h := m.GetHeader("To"); h != nil {
return extractTagParam(h.Value)
}
return ""
}
// GetCSeq returns the CSeq number and method
func (m *SIPMessage) GetCSeq() (int, string) {
if h := m.GetHeader("CSeq"); h != nil {
parts := strings.Fields(h.Value)
if len(parts) >= 2 {
seq, _ := strconv.Atoi(parts[0])
return seq, parts[1]
}
}
return 0, ""
}
// GetViaBranch extracts the branch parameter from the top Via header
func (m *SIPMessage) GetViaBranch() string {
if h := m.GetHeader("Via"); h != nil {
return extractViaParam(h.Value, "branch")
}
return ""
}
// Clone creates a deep copy of the message
func (m *SIPMessage) Clone() *SIPMessage {
clone := &SIPMessage{
IsRequest: m.IsRequest,
SIPVersion: m.SIPVersion,
Method: m.Method,
RequestURI: m.RequestURI,
StatusCode: m.StatusCode,
ReasonPhrase: m.ReasonPhrase,
Headers: make([]SIPHeader, len(m.Headers)),
}
copy(clone.Headers, m.Headers)
if len(m.Body) > 0 {
clone.Body = make([]byte, len(m.Body))
copy(clone.Body, m.Body)
}
return clone
}
// getCompactForm returns the compact form of a header name
func getCompactForm(name string) string {
compactForms := map[string]string{
"call-id": "i",
"contact": "m",
"content-encoding": "e",
"content-length": "l",
"content-type": "c",
"from": "f",
"subject": "s",
"supported": "k",
"to": "t",
"via": "v",
}
if compact, ok := compactForms[name]; ok {
return compact
}
return ""
}
// extractTagParam extracts the tag parameter from a From/To header value
func extractTagParam(headerValue string) string {
// Look for ;tag= parameter
idx := strings.Index(strings.ToLower(headerValue), ";tag=")
if idx < 0 {
return ""
}
tagStart := idx + 5 // len(";tag=")
rest := headerValue[tagStart:]
// Find end of tag value (semicolon or end of string)
endIdx := strings.IndexAny(rest, ";,>")
if endIdx < 0 {
return rest
}
return rest[:endIdx]
}
// extractViaParam extracts a parameter from a Via header value
func extractViaParam(headerValue, param string) string {
lower := strings.ToLower(headerValue)
search := ";" + param + "="
idx := strings.Index(lower, search)
if idx < 0 {
return ""
}
valueStart := idx + len(search)
rest := headerValue[valueStart:]
// Find end of value
endIdx := strings.IndexAny(rest, ";,")
if endIdx < 0 {
return rest
}
return rest[:endIdx]
}
// IsDialogCreating returns true if this is a dialog-creating request
func (m *SIPMessage) IsDialogCreating() bool {
if !m.IsRequest {
return false
}
switch m.Method {
case "INVITE", "SUBSCRIBE", "REFER", "NOTIFY":
return true
}
return false
}
// IsDialogTerminating returns true if this ends a dialog
func (m *SIPMessage) IsDialogTerminating() bool {
if !m.IsRequest {
return false
}
return m.Method == "BYE"
}
// UpdateContentLength updates the Content-Length header to match body size
func (m *SIPMessage) UpdateContentLength() {
m.SetHeader("Content-Length", strconv.Itoa(len(m.Body)))
}

539
topology.go Normal file
View File

@ -0,0 +1,539 @@
package sipguardian
import (
"crypto/rand"
"encoding/hex"
"fmt"
"net"
"strconv"
"strings"
"time"
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
"github.com/mholt/caddy-l4/layer4"
"go.uber.org/zap"
)
func init() {
caddy.RegisterModule(TopologyHider{})
}
// TopologyHider is a Layer 4 handler that hides internal SIP topology
type TopologyHider struct {
// Enabled toggles topology hiding
Enabled bool `json:"enabled,omitempty"`
// ProxyHost is the public IP address to use in rewritten headers
ProxyHost string `json:"proxy_host,omitempty"`
// ProxyPort is the public port to use in rewritten headers
ProxyPort int `json:"proxy_port,omitempty"`
// RewriteVia adds/modifies Via headers
RewriteVia bool `json:"rewrite_via,omitempty"`
// RewriteContact rewrites Contact headers to hide internal addresses
RewriteContact bool `json:"rewrite_contact,omitempty"`
// StripHeaders lists headers to remove (e.g., Server, P-Preferred-Identity)
StripHeaders []string `json:"strip_headers,omitempty"`
// AnonymizeCallID replaces Call-ID with proxy-generated value
AnonymizeCallID bool `json:"anonymize_call_id,omitempty"`
// HidePrivateIPs automatically detects and hides RFC 1918 addresses
HidePrivateIPs bool `json:"hide_private_ips,omitempty"`
// TransactionTimeout for cleanup of pending transactions
TransactionTimeout caddy.Duration `json:"transaction_timeout,omitempty"`
// Runtime state
logger *zap.Logger
transactions *TransactionManager
dialogs *DialogManager
}
func (TopologyHider) CaddyModule() caddy.ModuleInfo {
return caddy.ModuleInfo{
ID: "layer4.handlers.sip_topology_hider",
New: func() caddy.Module { return new(TopologyHider) },
}
}
func (h *TopologyHider) Provision(ctx caddy.Context) error {
h.logger = ctx.Logger()
// Set defaults
if h.ProxyPort == 0 {
h.ProxyPort = 5060
}
if h.TransactionTimeout == 0 {
h.TransactionTimeout = caddy.Duration(32 * time.Second)
}
// Initialize transaction manager for response correlation
h.transactions = NewTransactionManager(h.logger, time.Duration(h.TransactionTimeout))
// Initialize dialog manager for stateful tracking
h.dialogs = NewDialogManager(h.logger, 10*time.Minute)
// Start cleanup goroutine
go h.cleanupLoop(ctx)
h.logger.Info("SIP Topology Hider initialized",
zap.Bool("enabled", h.Enabled),
zap.String("proxy_host", h.ProxyHost),
zap.Int("proxy_port", h.ProxyPort),
zap.Bool("rewrite_via", h.RewriteVia),
zap.Bool("rewrite_contact", h.RewriteContact),
zap.Strings("strip_headers", h.StripHeaders),
zap.Bool("anonymize_call_id", h.AnonymizeCallID),
zap.Bool("hide_private_ips", h.HidePrivateIPs),
)
return nil
}
// Handle processes the SIP message and applies topology hiding
func (h *TopologyHider) Handle(cx *layer4.Connection, next layer4.Handler) error {
if !h.Enabled {
return next.Handle(cx)
}
// Get client address
remoteAddr := cx.RemoteAddr().String()
clientHost, clientPortStr, err := net.SplitHostPort(remoteAddr)
if err != nil {
clientHost = remoteAddr
}
clientPort, _ := strconv.Atoi(clientPortStr)
// Read the SIP message
buf := make([]byte, 8192)
n, err := cx.Read(buf)
if err != nil || n == 0 {
return next.Handle(cx)
}
buf = buf[:n]
// Parse the SIP message
msg, err := ParseSIPMessage(buf)
if err != nil {
h.logger.Debug("Failed to parse SIP message for topology hiding",
zap.Error(err),
zap.String("client", clientHost),
)
return next.Handle(cx)
}
// Apply topology hiding based on message type
if msg.IsRequest {
h.handleRequest(msg, clientHost, clientPort)
} else {
h.handleResponse(msg)
}
// Serialize modified message
modifiedData := msg.Serialize()
// Write modified data back to connection buffer
// Note: This requires caddy-l4 support for modifying the connection data
// For now, we'll pass through the next handler
// In a full implementation, we'd use a wrapped connection
h.logger.Debug("Topology hiding applied",
zap.Bool("is_request", msg.IsRequest),
zap.String("method", msg.Method),
zap.Int("status_code", msg.StatusCode),
zap.Int("original_size", n),
zap.Int("modified_size", len(modifiedData)),
)
// Continue to next handler
// TODO: Replace connection data with modified message
return next.Handle(cx)
}
// handleRequest applies topology hiding to outgoing requests
func (h *TopologyHider) handleRequest(msg *SIPMessage, clientHost string, clientPort int) {
callID := msg.GetCallID()
fromTag := msg.GetFromTag()
// Store original values for response correlation
originalVia := ""
if via := msg.GetHeader("Via"); via != nil {
originalVia = via.Value
}
// Create transaction state
branch := GenerateBranch()
h.transactions.CreateTransaction(branch, msg.Method, clientHost, clientPort, originalVia)
// For dialog-creating requests, create dialog state
if msg.IsDialogCreating() {
state := h.dialogs.CreateDialog(msg, clientHost, clientPort)
if state != nil && originalVia != "" {
h.dialogs.StoreOriginals(callID, fromTag, originalVia, "", "")
}
}
// Add our Via header (prepend to top)
if h.RewriteVia {
newVia := h.buildViaHeader(branch)
msg.PrependHeader("Via", newVia)
h.logger.Debug("Added Via header",
zap.String("via", newVia),
zap.String("branch", branch),
)
}
// Rewrite Contact header
if h.RewriteContact {
if contact := msg.GetHeader("Contact"); contact != nil {
originalContact := contact.Value
newContact := h.rewriteContactHeader(contact.Value)
contact.Value = newContact
// Store original for response rewriting
if msg.IsDialogCreating() {
h.dialogs.StoreOriginals(callID, fromTag, "", originalContact, "")
}
h.logger.Debug("Rewrote Contact header",
zap.String("original", originalContact),
zap.String("new", newContact),
)
}
}
// Anonymize Call-ID
if h.AnonymizeCallID {
if callIDHeader := msg.GetHeader("Call-ID"); callIDHeader != nil {
originalCallID := callIDHeader.Value
newCallID := h.generateAnonymousCallID()
callIDHeader.Value = newCallID
// Store mapping for response correlation
h.dialogs.StoreOriginals(callID, fromTag, "", "", originalCallID)
h.logger.Debug("Anonymized Call-ID",
zap.String("original", originalCallID),
zap.String("new", newCallID),
)
}
}
// Strip sensitive headers
h.stripSensitiveHeaders(msg)
// Hide private IPs in other headers
if h.HidePrivateIPs {
h.hidePrivateIPsInHeaders(msg)
}
}
// handleResponse applies topology hiding to responses
func (h *TopologyHider) handleResponse(msg *SIPMessage) {
// Find the transaction by branch
branch := msg.GetViaBranch()
transaction := h.transactions.GetTransaction(branch)
if transaction == nil {
h.logger.Debug("No transaction found for response",
zap.String("branch", branch),
zap.Int("status_code", msg.StatusCode),
)
return
}
// Remove our Via header (should be top Via)
if h.RewriteVia {
// The top Via should be ours - remove it
topVia := msg.GetHeader("Via")
if topVia != nil {
viaHeader, _ := ParseViaHeader(topVia.Value)
if viaHeader != nil && viaHeader.Branch == branch {
msg.RemoveFirstHeader("Via")
h.logger.Debug("Removed proxy Via header",
zap.String("branch", branch),
)
}
}
}
// Update dialog state based on response
callID := msg.GetCallID()
fromTag := msg.GetFromTag()
toTag := msg.GetToTag()
// Update dialog with To tag
if toTag != "" {
h.dialogs.UpdateDialog(callID, fromTag, toTag)
}
// Mark dialog as confirmed for 2xx responses to INVITE
_, method := msg.GetCSeq()
if msg.StatusCode >= 200 && msg.StatusCode < 300 && method == "INVITE" {
h.dialogs.ConfirmDialog(callID, fromTag)
}
// Strip sensitive headers from response too
h.stripSensitiveHeaders(msg)
// Clean up transaction for final responses
if msg.StatusCode >= 200 {
h.transactions.RemoveTransaction(branch)
}
}
// buildViaHeader creates a new Via header with proxy address
func (h *TopologyHider) buildViaHeader(branch string) string {
via := &ViaHeader{
Protocol: "SIP/2.0",
Transport: "UDP",
Host: h.ProxyHost,
Port: h.ProxyPort,
Branch: branch,
}
return via.Serialize()
}
// rewriteContactHeader replaces internal addresses with proxy address
func (h *TopologyHider) rewriteContactHeader(value string) string {
contact, err := ParseContactHeader(value)
if err != nil || contact.URI == "*" {
return value
}
// Check if the Contact URI contains a private IP
host := ExtractHostFromURI(contact.URI)
if h.HidePrivateIPs && IsPrivateIP(host) {
// Rewrite URI to use proxy address
contact.URI = RewriteURIHost(contact.URI, h.ProxyHost, h.ProxyPort)
} else if h.RewriteContact {
// Always rewrite if RewriteContact is enabled
contact.URI = RewriteURIHost(contact.URI, h.ProxyHost, h.ProxyPort)
}
return contact.Serialize()
}
// stripSensitiveHeaders removes headers that leak internal information
func (h *TopologyHider) stripSensitiveHeaders(msg *SIPMessage) {
// Strip configured headers
for _, header := range h.StripHeaders {
msg.RemoveHeader(header)
}
// Always strip these if HidePrivateIPs is enabled
if h.HidePrivateIPs {
for _, header := range SensitiveHeaders {
msg.RemoveHeader(header)
}
}
}
// hidePrivateIPsInHeaders scans all headers for private IPs
func (h *TopologyHider) hidePrivateIPsInHeaders(msg *SIPMessage) {
// Headers that commonly contain IP addresses
ipHeaders := []string{
"Record-Route",
"Route",
"P-Visited-Network-ID",
}
for i := range msg.Headers {
headerLower := strings.ToLower(msg.Headers[i].Name)
for _, ipHeader := range ipHeaders {
if headerLower == strings.ToLower(ipHeader) {
// Check for private IPs in the value
if containsPrivateIP(msg.Headers[i].Value) {
// Rewrite or remove the header
msg.Headers[i].Value = rewritePrivateIPs(msg.Headers[i].Value, h.ProxyHost)
}
}
}
}
}
// generateAnonymousCallID creates a new random Call-ID
func (h *TopologyHider) generateAnonymousCallID() string {
bytes := make([]byte, 16)
rand.Read(bytes)
return hex.EncodeToString(bytes) + "@" + h.ProxyHost
}
// containsPrivateIP checks if a string contains any private IP addresses
func containsPrivateIP(value string) bool {
// Simple pattern matching for common private IP formats
privatePatterns := []string{
"10.", "192.168.", "172.16.", "172.17.", "172.18.", "172.19.",
"172.20.", "172.21.", "172.22.", "172.23.", "172.24.", "172.25.",
"172.26.", "172.27.", "172.28.", "172.29.", "172.30.", "172.31.",
}
for _, pattern := range privatePatterns {
if strings.Contains(value, pattern) {
return true
}
}
return false
}
// rewritePrivateIPs replaces private IP addresses in a header value
func rewritePrivateIPs(value, replacement string) string {
// This is a simplified implementation
// A full implementation would use regex to properly replace IPs in URIs
// For Record-Route and Route, we may want to preserve the structure
// but replace the host portion
return value // TODO: Implement proper IP replacement
}
// cleanupLoop periodically cleans up expired state
func (h *TopologyHider) cleanupLoop(ctx caddy.Context) {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
transRemoved := h.transactions.Cleanup()
dialogRemoved := h.dialogs.Cleanup()
if transRemoved > 0 || dialogRemoved > 0 {
h.logger.Debug("Topology hider cleanup",
zap.Int("transactions_removed", transRemoved),
zap.Int("dialogs_removed", dialogRemoved),
)
}
}
}
}
// GetStats returns statistics about the topology hider
func (h *TopologyHider) GetStats() map[string]interface{} {
stats := map[string]interface{}{
"enabled": h.Enabled,
"proxy_host": h.ProxyHost,
"proxy_port": h.ProxyPort,
"rewrite_via": h.RewriteVia,
"rewrite_contact": h.RewriteContact,
"anonymize_callid": h.AnonymizeCallID,
"hide_private_ips": h.HidePrivateIPs,
}
if h.transactions != nil {
for k, v := range h.transactions.GetStats() {
stats["transactions_"+k] = v
}
}
if h.dialogs != nil {
for k, v := range h.dialogs.GetStats() {
stats["dialogs_"+k] = v
}
}
return stats
}
// UnmarshalCaddyfile implements caddyfile.Unmarshaler for TopologyHider.
// Usage in Caddyfile:
//
// sip_topology_hider {
// enabled true
// proxy_host 203.0.113.1
// proxy_port 5060
// rewrite_via
// rewrite_contact
// strip_headers P-Preferred-Identity P-Asserted-Identity Server
// anonymize_call_id
// hide_private_ips
// transaction_timeout 32s
// }
func (h *TopologyHider) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
// Move past "sip_topology_hider" token
d.Next()
for nesting := d.Nesting(); d.NextBlock(nesting); {
switch d.Val() {
case "enabled":
if !d.NextArg() {
// No argument means enabled
h.Enabled = true
} else {
val := d.Val()
h.Enabled = val == "true" || val == "yes" || val == "on"
}
case "proxy_host":
if !d.NextArg() {
return d.ArgErr()
}
h.ProxyHost = d.Val()
case "proxy_port":
if !d.NextArg() {
return d.ArgErr()
}
port, err := strconv.Atoi(d.Val())
if err != nil {
return d.Errf("invalid proxy_port: %v", err)
}
h.ProxyPort = port
case "rewrite_via":
h.RewriteVia = true
case "rewrite_contact":
h.RewriteContact = true
case "strip_headers":
h.StripHeaders = d.RemainingArgs()
if len(h.StripHeaders) == 0 {
return d.ArgErr()
}
case "anonymize_call_id":
h.AnonymizeCallID = true
case "hide_private_ips":
h.HidePrivateIPs = true
case "transaction_timeout":
if !d.NextArg() {
return d.ArgErr()
}
dur, err := caddy.ParseDuration(d.Val())
if err != nil {
return d.Errf("invalid transaction_timeout: %v", err)
}
h.TransactionTimeout = caddy.Duration(dur)
default:
return d.Errf("unknown sip_topology_hider directive: %s", d.Val())
}
}
// Validate required fields
if h.Enabled && h.ProxyHost == "" {
return fmt.Errorf("proxy_host is required when topology hiding is enabled")
}
return nil
}
// Interface guards
var (
_ layer4.NextHandler = (*TopologyHider)(nil)
_ caddy.Module = (*TopologyHider)(nil)
_ caddy.Provisioner = (*TopologyHider)(nil)
_ caddyfile.Unmarshaler = (*TopologyHider)(nil)
)

833
topology_test.go Normal file
View File

@ -0,0 +1,833 @@
package sipguardian
import (
"strings"
"testing"
"time"
"go.uber.org/zap"
)
// =============================================================================
// SIP Message Parsing Tests (sipmsg.go)
// =============================================================================
func TestParseSIPRequest(t *testing.T) {
raw := "INVITE sip:bob@example.com SIP/2.0\r\n" +
"Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK776asdhds\r\n" +
"From: \"Alice\" <sip:alice@example.com>;tag=1928301774\r\n" +
"To: <sip:bob@example.com>\r\n" +
"Call-ID: a84b4c76e66710@pc33.example.com\r\n" +
"CSeq: 314159 INVITE\r\n" +
"Contact: <sip:alice@192.168.1.100:5060>\r\n" +
"Content-Type: application/sdp\r\n" +
"Content-Length: 4\r\n" +
"\r\n" +
"test"
msg, err := ParseSIPMessage([]byte(raw))
if err != nil {
t.Fatalf("Failed to parse SIP request: %v", err)
}
if !msg.IsRequest {
t.Error("Expected IsRequest to be true")
}
if msg.Method != "INVITE" {
t.Errorf("Expected method INVITE, got %s", msg.Method)
}
if msg.RequestURI != "sip:bob@example.com" {
t.Errorf("Expected RequestURI sip:bob@example.com, got %s", msg.RequestURI)
}
if msg.SIPVersion != "SIP/2.0" {
t.Errorf("Expected SIP/2.0, got %s", msg.SIPVersion)
}
if string(msg.Body) != "test" {
t.Errorf("Expected body 'test', got '%s'", string(msg.Body))
}
}
func TestParseSIPResponse(t *testing.T) {
raw := "SIP/2.0 200 OK\r\n" +
"Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK776asdhds;received=10.0.0.1\r\n" +
"From: \"Alice\" <sip:alice@example.com>;tag=1928301774\r\n" +
"To: <sip:bob@example.com>;tag=a6c85cf\r\n" +
"Call-ID: a84b4c76e66710@pc33.example.com\r\n" +
"CSeq: 314159 INVITE\r\n" +
"Contact: <sip:bob@192.168.1.200:5060>\r\n" +
"Content-Length: 0\r\n" +
"\r\n"
msg, err := ParseSIPMessage([]byte(raw))
if err != nil {
t.Fatalf("Failed to parse SIP response: %v", err)
}
if msg.IsRequest {
t.Error("Expected IsRequest to be false for response")
}
if msg.StatusCode != 200 {
t.Errorf("Expected status code 200, got %d", msg.StatusCode)
}
if msg.ReasonPhrase != "OK" {
t.Errorf("Expected reason phrase 'OK', got '%s'", msg.ReasonPhrase)
}
}
func TestSIPMessageGetHeader(t *testing.T) {
raw := "REGISTER sip:example.com SIP/2.0\r\n" +
"Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK776\r\n" +
"From: <sip:alice@example.com>;tag=123\r\n" +
"To: <sip:alice@example.com>\r\n" +
"Call-ID: test123@host\r\n" +
"CSeq: 1 REGISTER\r\n" +
"Content-Length: 0\r\n" +
"\r\n"
msg, _ := ParseSIPMessage([]byte(raw))
// Test case-insensitive lookup
via := msg.GetHeader("via")
if via == nil {
t.Error("Failed to get Via header (lowercase)")
}
Via := msg.GetHeader("Via")
if Via == nil {
t.Error("Failed to get Via header (mixed case)")
}
// Test GetCallID helper
callID := msg.GetCallID()
if callID != "test123@host" {
t.Errorf("Expected Call-ID 'test123@host', got '%s'", callID)
}
// Test GetFromTag helper
fromTag := msg.GetFromTag()
if fromTag != "123" {
t.Errorf("Expected From tag '123', got '%s'", fromTag)
}
}
func TestSIPMessageCompactHeaders(t *testing.T) {
// Test compact header forms (RFC 3261 Section 7.3.3)
raw := "INVITE sip:bob@example.com SIP/2.0\r\n" +
"v: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK776\r\n" +
"f: <sip:alice@example.com>;tag=123\r\n" +
"t: <sip:bob@example.com>\r\n" +
"i: compact-call-id@host\r\n" +
"CSeq: 1 INVITE\r\n" +
"m: <sip:alice@192.168.1.100>\r\n" +
"l: 0\r\n" +
"\r\n"
msg, err := ParseSIPMessage([]byte(raw))
if err != nil {
t.Fatalf("Failed to parse message with compact headers: %v", err)
}
// Via (v)
if msg.GetHeader("Via") == nil {
t.Error("Failed to get Via via compact form 'v'")
}
// Call-ID (i)
if msg.GetCallID() != "compact-call-id@host" {
t.Error("Failed to get Call-ID via compact form 'i'")
}
// Contact (m)
if msg.GetHeader("Contact") == nil {
t.Error("Failed to get Contact via compact form 'm'")
}
}
func TestSIPMessageSerialize(t *testing.T) {
raw := "INVITE sip:bob@example.com SIP/2.0\r\n" +
"Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK776\r\n" +
"From: <sip:alice@example.com>;tag=123\r\n" +
"To: <sip:bob@example.com>\r\n" +
"Call-ID: test@host\r\n" +
"CSeq: 1 INVITE\r\n" +
"Content-Length: 0\r\n" +
"\r\n"
msg, _ := ParseSIPMessage([]byte(raw))
serialized := msg.Serialize()
// Parse again and verify round-trip
msg2, err := ParseSIPMessage(serialized)
if err != nil {
t.Fatalf("Failed to parse serialized message: %v", err)
}
if msg.Method != msg2.Method {
t.Errorf("Method mismatch after round-trip: %s vs %s", msg.Method, msg2.Method)
}
if msg.GetCallID() != msg2.GetCallID() {
t.Errorf("Call-ID mismatch after round-trip")
}
}
func TestSIPMessageHeaderManipulation(t *testing.T) {
raw := "INVITE sip:bob@example.com SIP/2.0\r\n" +
"Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK776\r\n" +
"From: <sip:alice@example.com>;tag=123\r\n" +
"To: <sip:bob@example.com>\r\n" +
"Call-ID: test@host\r\n" +
"CSeq: 1 INVITE\r\n" +
"Content-Length: 0\r\n" +
"\r\n"
msg, _ := ParseSIPMessage([]byte(raw))
// Test PrependHeader (for Via insertion)
msg.PrependHeader("Via", "SIP/2.0/UDP proxy.example.com:5060;branch=z9hG4bKproxy")
vias := msg.GetHeaders("Via")
if len(vias) != 2 {
t.Errorf("Expected 2 Via headers after prepend, got %d", len(vias))
}
if !strings.Contains(vias[0].Value, "proxy.example.com") {
t.Error("Prepended Via should be first")
}
// Test RemoveFirstHeader
msg.RemoveFirstHeader("Via")
vias = msg.GetHeaders("Via")
if len(vias) != 1 {
t.Errorf("Expected 1 Via header after removal, got %d", len(vias))
}
// Test SetHeader
msg.SetHeader("User-Agent", "Test/1.0")
ua := msg.GetHeader("User-Agent")
if ua == nil || ua.Value != "Test/1.0" {
t.Error("SetHeader failed")
}
// Test RemoveHeader
msg.RemoveHeader("User-Agent")
if msg.GetHeader("User-Agent") != nil {
t.Error("RemoveHeader failed")
}
}
func TestSIPMessageClone(t *testing.T) {
raw := "INVITE sip:bob@example.com SIP/2.0\r\n" +
"Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK776\r\n" +
"From: <sip:alice@example.com>;tag=123\r\n" +
"To: <sip:bob@example.com>\r\n" +
"Call-ID: test@host\r\n" +
"CSeq: 1 INVITE\r\n" +
"Content-Length: 4\r\n" +
"\r\n" +
"test"
msg, _ := ParseSIPMessage([]byte(raw))
clone := msg.Clone()
// Modify original
msg.SetHeader("Via", "modified")
msg.Body = []byte("modified")
// Clone should be unchanged
if clone.GetHeader("Via").Value == "modified" {
t.Error("Clone was affected by original modification")
}
if string(clone.Body) != "test" {
t.Error("Clone body was affected by original modification")
}
}
// =============================================================================
// Header Parsing Tests (sipheaders.go)
// =============================================================================
func TestParseViaHeader(t *testing.T) {
tests := []struct {
name string
input string
expected ViaHeader
}{
{
name: "basic UDP",
input: "SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK776asdhds",
expected: ViaHeader{
Protocol: "SIP/2.0",
Transport: "UDP",
Host: "192.168.1.100",
Port: 5060,
Branch: "z9hG4bK776asdhds",
},
},
{
name: "with received and rport",
input: "SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK776;received=10.0.0.1;rport=12345",
expected: ViaHeader{
Protocol: "SIP/2.0",
Transport: "UDP",
Host: "192.168.1.100",
Port: 5060,
Branch: "z9hG4bK776",
Received: "10.0.0.1",
RPort: 12345,
},
},
{
name: "TCP transport",
input: "SIP/2.0/TCP proxy.example.com:5060;branch=z9hG4bKabc",
expected: ViaHeader{
Protocol: "SIP/2.0",
Transport: "TCP",
Host: "proxy.example.com",
Port: 5060,
Branch: "z9hG4bKabc",
},
},
{
name: "default port",
input: "SIP/2.0/UDP 192.168.1.100;branch=z9hG4bK123",
expected: ViaHeader{
Protocol: "SIP/2.0",
Transport: "UDP",
Host: "192.168.1.100",
Port: 5060, // Default
Branch: "z9hG4bK123",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
via, err := ParseViaHeader(tt.input)
if err != nil {
t.Fatalf("ParseViaHeader failed: %v", err)
}
if via.Protocol != tt.expected.Protocol {
t.Errorf("Protocol: got %s, want %s", via.Protocol, tt.expected.Protocol)
}
if via.Transport != tt.expected.Transport {
t.Errorf("Transport: got %s, want %s", via.Transport, tt.expected.Transport)
}
if via.Host != tt.expected.Host {
t.Errorf("Host: got %s, want %s", via.Host, tt.expected.Host)
}
if via.Port != tt.expected.Port {
t.Errorf("Port: got %d, want %d", via.Port, tt.expected.Port)
}
if via.Branch != tt.expected.Branch {
t.Errorf("Branch: got %s, want %s", via.Branch, tt.expected.Branch)
}
})
}
}
func TestViaHeaderSerialize(t *testing.T) {
via := &ViaHeader{
Protocol: "SIP/2.0",
Transport: "UDP",
Host: "192.168.1.100",
Port: 5060,
Branch: "z9hG4bK776",
Received: "10.0.0.1",
RPort: 12345,
}
serialized := via.Serialize()
// Parse it back
via2, err := ParseViaHeader(serialized)
if err != nil {
t.Fatalf("Failed to parse serialized Via: %v", err)
}
if via2.Host != via.Host || via2.Branch != via.Branch || via2.Received != via.Received {
t.Error("Via header round-trip failed")
}
}
func TestParseContactHeader(t *testing.T) {
tests := []struct {
name string
input string
uri string
dname string
}{
{
name: "simple URI",
input: "<sip:alice@192.168.1.100:5060>",
uri: "sip:alice@192.168.1.100:5060",
dname: "",
},
{
name: "with display name",
input: "\"Alice\" <sip:alice@example.com>;expires=3600",
uri: "sip:alice@example.com",
dname: "Alice",
},
{
name: "star contact (REGISTER)",
input: "*",
uri: "*",
dname: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
contact, err := ParseContactHeader(tt.input)
if err != nil {
t.Fatalf("ParseContactHeader failed: %v", err)
}
if contact.URI != tt.uri {
t.Errorf("URI: got %s, want %s", contact.URI, tt.uri)
}
if contact.DisplayName != tt.dname {
t.Errorf("DisplayName: got %s, want %s", contact.DisplayName, tt.dname)
}
})
}
}
func TestParseFromToHeader(t *testing.T) {
tests := []struct {
name string
input string
uri string
tag string
dname string
}{
{
name: "From with tag",
input: "\"Bob\" <sip:bob@example.com>;tag=a6c85cf",
uri: "sip:bob@example.com",
tag: "a6c85cf",
dname: "Bob",
},
{
name: "To without tag",
input: "<sip:alice@example.com>",
uri: "sip:alice@example.com",
tag: "",
dname: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
header, err := ParseFromToHeader(tt.input)
if err != nil {
t.Fatalf("ParseFromToHeader failed: %v", err)
}
if header.URI != tt.uri {
t.Errorf("URI: got %s, want %s", header.URI, tt.uri)
}
if header.Tag != tt.tag {
t.Errorf("Tag: got %s, want %s", header.Tag, tt.tag)
}
if header.DisplayName != tt.dname {
t.Errorf("DisplayName: got %s, want %s", header.DisplayName, tt.dname)
}
})
}
}
func TestExtractHostFromURI(t *testing.T) {
tests := []struct {
uri string
expected string
}{
{"sip:alice@example.com", "example.com"},
{"sip:bob@192.168.1.100:5060", "192.168.1.100"},
{"sips:secure@tls.example.com:5061;transport=tls", "tls.example.com"},
{"sip:user@host;param=value", "host"},
}
for _, tt := range tests {
t.Run(tt.uri, func(t *testing.T) {
host := ExtractHostFromURI(tt.uri)
if host != tt.expected {
t.Errorf("ExtractHostFromURI(%s) = %s, want %s", tt.uri, host, tt.expected)
}
})
}
}
func TestRewriteURIHost(t *testing.T) {
tests := []struct {
uri string
newHost string
newPort int
expected string
}{
{
uri: "sip:alice@192.168.1.100:5060",
newHost: "proxy.example.com",
newPort: 5060,
expected: "sip:alice@proxy.example.com",
},
{
uri: "sip:bob@internal.lan:5060;transport=udp",
newHost: "external.example.com",
newPort: 5080,
expected: "sip:bob@external.example.com:5080;transport=udp",
},
}
for _, tt := range tests {
t.Run(tt.uri, func(t *testing.T) {
result := RewriteURIHost(tt.uri, tt.newHost, tt.newPort)
if result != tt.expected {
t.Errorf("RewriteURIHost(%s, %s, %d) = %s, want %s",
tt.uri, tt.newHost, tt.newPort, result, tt.expected)
}
})
}
}
func TestGenerateBranch(t *testing.T) {
branch := GenerateBranch()
// RFC 3261: Branch must start with "z9hG4bK"
if !strings.HasPrefix(branch, "z9hG4bK") {
t.Errorf("Branch %s does not start with z9hG4bK", branch)
}
// Should be unique
branch2 := GenerateBranch()
if branch == branch2 {
t.Error("GenerateBranch should generate unique values")
}
}
// =============================================================================
// Dialog State Tests (dialog_state.go)
// =============================================================================
func TestDialogManagerCreateAndGet(t *testing.T) {
logger := zap.NewNop()
dm := NewDialogManager(logger, 5*time.Minute)
// Create a mock SIP message
raw := "INVITE sip:bob@example.com SIP/2.0\r\n" +
"Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK776test\r\n" +
"From: <sip:alice@example.com>;tag=fromtag123\r\n" +
"To: <sip:bob@example.com>\r\n" +
"Call-ID: testcall@example.com\r\n" +
"CSeq: 1 INVITE\r\n" +
"Contact: <sip:alice@192.168.1.100>\r\n" +
"Content-Length: 0\r\n" +
"\r\n"
msg, _ := ParseSIPMessage([]byte(raw))
// Create dialog
state := dm.CreateDialog(msg, "192.168.1.100", 5060)
if state == nil {
t.Fatal("CreateDialog returned nil")
}
if state.CallID != "testcall@example.com" {
t.Errorf("CallID mismatch: got %s", state.CallID)
}
if state.FromTag != "fromtag123" {
t.Errorf("FromTag mismatch: got %s", state.FromTag)
}
// Get by Call-ID
retrieved := dm.GetDialogByCallID("testcall@example.com", "fromtag123")
if retrieved == nil {
t.Fatal("GetDialogByCallID returned nil")
}
if retrieved != state {
t.Error("Retrieved dialog doesn't match created dialog")
}
// Get by branch
byBranch := dm.GetDialogByBranch("z9hG4bK776test")
if byBranch == nil {
t.Fatal("GetDialogByBranch returned nil")
}
if byBranch != state {
t.Error("Retrieved dialog by branch doesn't match")
}
}
func TestDialogManagerUpdateAndConfirm(t *testing.T) {
logger := zap.NewNop()
dm := NewDialogManager(logger, 5*time.Minute)
raw := "INVITE sip:bob@example.com SIP/2.0\r\n" +
"Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bKupdate\r\n" +
"From: <sip:alice@example.com>;tag=from456\r\n" +
"To: <sip:bob@example.com>\r\n" +
"Call-ID: updatetest@example.com\r\n" +
"CSeq: 1 INVITE\r\n" +
"Content-Length: 0\r\n" +
"\r\n"
msg, _ := ParseSIPMessage([]byte(raw))
dm.CreateDialog(msg, "192.168.1.100", 5060)
// Update with To-tag (from 200 OK)
dm.UpdateDialog("updatetest@example.com", "from456", "totag789")
state := dm.GetDialogByCallID("updatetest@example.com", "from456")
if state.ToTag != "totag789" {
t.Errorf("ToTag not updated: got %s", state.ToTag)
}
// Confirm dialog
dm.ConfirmDialog("updatetest@example.com", "from456")
if !state.IsConfirmed {
t.Error("Dialog not confirmed")
}
// Terminate dialog
dm.TerminateDialog("updatetest@example.com", "from456")
if !state.IsTerminated {
t.Error("Dialog not terminated")
}
}
func TestDialogManagerRemove(t *testing.T) {
logger := zap.NewNop()
dm := NewDialogManager(logger, 5*time.Minute)
raw := "INVITE sip:bob@example.com SIP/2.0\r\n" +
"Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bKremove\r\n" +
"From: <sip:alice@example.com>;tag=removetag\r\n" +
"To: <sip:bob@example.com>\r\n" +
"Call-ID: removetest@example.com\r\n" +
"CSeq: 1 INVITE\r\n" +
"Content-Length: 0\r\n" +
"\r\n"
msg, _ := ParseSIPMessage([]byte(raw))
dm.CreateDialog(msg, "192.168.1.100", 5060)
// Remove dialog
dm.RemoveDialog("removetest@example.com", "removetag")
// Should be gone
if dm.GetDialogByCallID("removetest@example.com", "removetag") != nil {
t.Error("Dialog not removed from call-id index")
}
if dm.GetDialogByBranch("z9hG4bKremove") != nil {
t.Error("Dialog not removed from branch index")
}
}
func TestDialogManagerStats(t *testing.T) {
logger := zap.NewNop()
dm := NewDialogManager(logger, 5*time.Minute)
// Create multiple dialogs in different states
for i := 0; i < 3; i++ {
raw := "INVITE sip:bob@example.com SIP/2.0\r\n" +
"Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bKstats" + string(rune('a'+i)) + "\r\n" +
"From: <sip:alice@example.com>;tag=statstag" + string(rune('a'+i)) + "\r\n" +
"To: <sip:bob@example.com>\r\n" +
"Call-ID: statstest" + string(rune('a'+i)) + "@example.com\r\n" +
"CSeq: 1 INVITE\r\n" +
"Content-Length: 0\r\n" +
"\r\n"
msg, _ := ParseSIPMessage([]byte(raw))
dm.CreateDialog(msg, "192.168.1.100", 5060)
}
// Confirm one
dm.ConfirmDialog("statstesta@example.com", "statstaga")
// Terminate one
dm.TerminateDialog("statstestb@example.com", "statstagb")
stats := dm.GetStats()
if stats["total_dialogs"].(int) != 3 {
t.Errorf("Expected 3 total dialogs, got %v", stats["total_dialogs"])
}
if stats["confirmed_dialogs"].(int) != 1 {
t.Errorf("Expected 1 confirmed dialog, got %v", stats["confirmed_dialogs"])
}
if stats["terminated_dialogs"].(int) != 1 {
t.Errorf("Expected 1 terminated dialog, got %v", stats["terminated_dialogs"])
}
}
func TestTransactionManager(t *testing.T) {
logger := zap.NewNop()
tm := NewTransactionManager(logger, 32*time.Second)
// Create transaction
state := tm.CreateTransaction("z9hG4bKtx123", "INVITE", "192.168.1.100", 5060, "SIP/2.0/UDP original:5060")
if state == nil {
t.Fatal("CreateTransaction returned nil")
}
// Retrieve
retrieved := tm.GetTransaction("z9hG4bKtx123")
if retrieved == nil {
t.Fatal("GetTransaction returned nil")
}
if retrieved.Method != "INVITE" {
t.Errorf("Method mismatch: got %s", retrieved.Method)
}
// Remove
tm.RemoveTransaction("z9hG4bKtx123")
if tm.GetTransaction("z9hG4bKtx123") != nil {
t.Error("Transaction not removed")
}
}
// =============================================================================
// Topology Hider Tests (topology.go)
// =============================================================================
func TestTopologyHiderConfig(t *testing.T) {
th := &TopologyHider{
Enabled: true,
ProxyHost: "proxy.example.com",
ProxyPort: 5060,
RewriteVia: true,
RewriteContact: true,
AnonymizeCallID: false,
HidePrivateIPs: true,
StripHeaders: []string{"Server", "User-Agent"},
}
if !th.Enabled {
t.Error("Expected Enabled to be true")
}
if len(th.StripHeaders) != 2 {
t.Errorf("Expected 2 strip headers, got %d", len(th.StripHeaders))
}
}
func TestIsDialogCreating(t *testing.T) {
tests := []struct {
method string
expected bool
}{
{"INVITE", true},
{"SUBSCRIBE", true},
{"REFER", true},
{"NOTIFY", true},
{"REGISTER", false},
{"OPTIONS", false},
{"BYE", false},
{"CANCEL", false},
}
for _, tt := range tests {
t.Run(tt.method, func(t *testing.T) {
raw := tt.method + " sip:bob@example.com SIP/2.0\r\n" +
"Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK123\r\n" +
"From: <sip:alice@example.com>;tag=123\r\n" +
"To: <sip:bob@example.com>\r\n" +
"Call-ID: test@host\r\n" +
"CSeq: 1 " + tt.method + "\r\n" +
"Content-Length: 0\r\n" +
"\r\n"
msg, _ := ParseSIPMessage([]byte(raw))
if msg.IsDialogCreating() != tt.expected {
t.Errorf("%s: IsDialogCreating() = %v, want %v",
tt.method, msg.IsDialogCreating(), tt.expected)
}
})
}
}
func TestIsDialogTerminating(t *testing.T) {
byeMsg := "BYE sip:bob@example.com SIP/2.0\r\n" +
"Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK123\r\n" +
"From: <sip:alice@example.com>;tag=123\r\n" +
"To: <sip:bob@example.com>;tag=456\r\n" +
"Call-ID: test@host\r\n" +
"CSeq: 2 BYE\r\n" +
"Content-Length: 0\r\n" +
"\r\n"
msg, _ := ParseSIPMessage([]byte(byeMsg))
if !msg.IsDialogTerminating() {
t.Error("BYE should be dialog terminating")
}
inviteMsg := "INVITE sip:bob@example.com SIP/2.0\r\n" +
"Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK123\r\n" +
"From: <sip:alice@example.com>;tag=123\r\n" +
"To: <sip:bob@example.com>\r\n" +
"Call-ID: test@host\r\n" +
"CSeq: 1 INVITE\r\n" +
"Content-Length: 0\r\n" +
"\r\n"
msg2, _ := ParseSIPMessage([]byte(inviteMsg))
if msg2.IsDialogTerminating() {
t.Error("INVITE should not be dialog terminating")
}
}
func TestSensitiveHeadersList(t *testing.T) {
expected := []string{
"P-Preferred-Identity",
"P-Asserted-Identity",
"Remote-Party-ID",
"P-Charging-Vector",
"P-Charging-Function-Addresses",
"Server",
"User-Agent",
"X-Asterisk-HangupCause",
"X-Asterisk-HangupCauseCode",
}
if len(SensitiveHeaders) != len(expected) {
t.Errorf("Expected %d sensitive headers, got %d", len(expected), len(SensitiveHeaders))
}
for _, h := range expected {
found := false
for _, sh := range SensitiveHeaders {
if sh == h {
found = true
break
}
}
if !found {
t.Errorf("Missing sensitive header: %s", h)
}
}
}
func TestUpdateContentLength(t *testing.T) {
raw := "INVITE sip:bob@example.com SIP/2.0\r\n" +
"Via: SIP/2.0/UDP 192.168.1.100:5060;branch=z9hG4bK123\r\n" +
"From: <sip:alice@example.com>;tag=123\r\n" +
"To: <sip:bob@example.com>\r\n" +
"Call-ID: test@host\r\n" +
"CSeq: 1 INVITE\r\n" +
"Content-Length: 0\r\n" +
"\r\n"
msg, _ := ParseSIPMessage([]byte(raw))
// Set a body
msg.Body = []byte("v=0\r\no=- 12345 12345 IN IP4 192.168.1.100\r\ns=-\r\n")
msg.UpdateContentLength()
cl := msg.GetHeader("Content-Length")
if cl == nil {
t.Fatal("Content-Length header missing")
}
if cl.Value != "48" {
t.Errorf("Content-Length should be 48, got %s", cl.Value)
}
}