package sipguardian import ( "fmt" "regexp" "strconv" "strings" ) // ViaHeader represents a parsed Via header type ViaHeader struct { Protocol string // SIP/2.0 Transport string // UDP, TCP, TLS, WS, WSS Host string Port int Branch string Received string RPort int Params map[string]string // Other parameters } // ContactHeader represents a parsed Contact header type ContactHeader struct { DisplayName string URI string Params map[string]string } // FromToHeader represents a parsed From or To header type FromToHeader struct { DisplayName string URI string Tag string Params map[string]string } // SensitiveHeaders lists headers that may leak internal topology var SensitiveHeaders = []string{ "P-Preferred-Identity", "P-Asserted-Identity", "Remote-Party-ID", "P-Charging-Vector", "P-Charging-Function-Addresses", "Server", "User-Agent", "X-Asterisk-HangupCause", "X-Asterisk-HangupCauseCode", } // ParseViaHeader parses a Via header value into structured form // Example: "SIP/2.0/UDP 192.168.1.1:5060;branch=z9hG4bK123;received=10.0.0.1" func ParseViaHeader(value string) (*ViaHeader, error) { via := &ViaHeader{ Params: make(map[string]string), } // Split by semicolons to get parameters parts := strings.Split(value, ";") if len(parts) == 0 { return nil, fmt.Errorf("empty Via header") } // First part: SIP/2.0/UDP host:port mainPart := strings.TrimSpace(parts[0]) fields := strings.Fields(mainPart) if len(fields) < 2 { return nil, fmt.Errorf("invalid Via header format") } // Parse protocol/transport protoTransport := strings.Split(fields[0], "/") if len(protoTransport) >= 2 { via.Protocol = protoTransport[0] + "/" + protoTransport[1] } if len(protoTransport) >= 3 { via.Transport = protoTransport[2] } // Parse host:port hostPort := fields[1] if colonIdx := strings.LastIndex(hostPort, ":"); colonIdx > 0 { via.Host = hostPort[:colonIdx] port, _ := strconv.Atoi(hostPort[colonIdx+1:]) via.Port = port } else { via.Host = hostPort via.Port = 5060 // Default SIP port } // Parse parameters for i := 1; i < len(parts); i++ { param := strings.TrimSpace(parts[i]) if eqIdx := strings.Index(param, "="); eqIdx > 0 { key := strings.ToLower(param[:eqIdx]) val := param[eqIdx+1:] switch key { case "branch": via.Branch = val case "received": via.Received = val case "rport": via.RPort, _ = strconv.Atoi(val) default: via.Params[key] = val } } else { // Flag parameter (no value) via.Params[strings.ToLower(param)] = "" } } return via, nil } // Serialize converts ViaHeader back to string format func (v *ViaHeader) Serialize() string { var buf strings.Builder // Protocol and transport buf.WriteString(v.Protocol) buf.WriteByte('/') buf.WriteString(v.Transport) buf.WriteByte(' ') // Host and port buf.WriteString(v.Host) if v.Port > 0 && v.Port != 5060 { buf.WriteByte(':') buf.WriteString(strconv.Itoa(v.Port)) } // Branch (required for RFC 3261 compliance) if v.Branch != "" { buf.WriteString(";branch=") buf.WriteString(v.Branch) } // Received parameter if v.Received != "" { buf.WriteString(";received=") buf.WriteString(v.Received) } // RPort parameter if v.RPort > 0 { buf.WriteString(";rport=") buf.WriteString(strconv.Itoa(v.RPort)) } // Other parameters for key, val := range v.Params { buf.WriteByte(';') buf.WriteString(key) if val != "" { buf.WriteByte('=') buf.WriteString(val) } } return buf.String() } // ParseContactHeader parses a Contact header value // Example: "\"Alice\" ;expires=3600" func ParseContactHeader(value string) (*ContactHeader, error) { contact := &ContactHeader{ Params: make(map[string]string), } // Handle * (special Contact for REGISTER) if strings.TrimSpace(value) == "*" { contact.URI = "*" return contact, nil } // Extract display name if present (quoted) if idx := strings.Index(value, "\""); idx >= 0 { endIdx := strings.Index(value[idx+1:], "\"") if endIdx > 0 { contact.DisplayName = value[idx+1 : idx+1+endIdx] value = value[idx+1+endIdx+1:] } } // Extract URI (within angle brackets or until semicolon) value = strings.TrimSpace(value) if strings.HasPrefix(value, "<") { endIdx := strings.Index(value, ">") if endIdx > 0 { contact.URI = value[1:endIdx] value = value[endIdx+1:] } } else { // No angle brackets - URI goes until semicolon or end semiIdx := strings.Index(value, ";") if semiIdx > 0 { contact.URI = strings.TrimSpace(value[:semiIdx]) value = value[semiIdx:] } else { contact.URI = strings.TrimSpace(value) value = "" } } // Parse parameters for _, part := range strings.Split(value, ";") { part = strings.TrimSpace(part) if part == "" { continue } if eqIdx := strings.Index(part, "="); eqIdx > 0 { contact.Params[strings.ToLower(part[:eqIdx])] = part[eqIdx+1:] } else if part != "" { contact.Params[strings.ToLower(part)] = "" } } return contact, nil } // Serialize converts ContactHeader back to string format func (c *ContactHeader) Serialize() string { var buf strings.Builder if c.URI == "*" { return "*" } if c.DisplayName != "" { buf.WriteByte('"') buf.WriteString(c.DisplayName) buf.WriteString("\" ") } buf.WriteByte('<') buf.WriteString(c.URI) buf.WriteByte('>') for key, val := range c.Params { buf.WriteByte(';') buf.WriteString(key) if val != "" { buf.WriteByte('=') buf.WriteString(val) } } return buf.String() } // ParseFromToHeader parses a From or To header value // Example: "\"Bob\" ;tag=abc123" func ParseFromToHeader(value string) (*FromToHeader, error) { header := &FromToHeader{ Params: make(map[string]string), } // Extract display name if present (quoted) if idx := strings.Index(value, "\""); idx >= 0 { endIdx := strings.Index(value[idx+1:], "\"") if endIdx > 0 { header.DisplayName = value[idx+1 : idx+1+endIdx] value = value[idx+1+endIdx+1:] } } // Extract URI (within angle brackets) value = strings.TrimSpace(value) if strings.HasPrefix(value, "<") { endIdx := strings.Index(value, ">") if endIdx > 0 { header.URI = value[1:endIdx] value = value[endIdx+1:] } } else { // No angle brackets semiIdx := strings.Index(value, ";") if semiIdx > 0 { header.URI = strings.TrimSpace(value[:semiIdx]) value = value[semiIdx:] } else { header.URI = strings.TrimSpace(value) value = "" } } // Parse parameters for _, part := range strings.Split(value, ";") { part = strings.TrimSpace(part) if part == "" { continue } if eqIdx := strings.Index(part, "="); eqIdx > 0 { key := strings.ToLower(part[:eqIdx]) val := part[eqIdx+1:] if key == "tag" { header.Tag = val } else { header.Params[key] = val } } } return header, nil } // Serialize converts FromToHeader back to string format func (f *FromToHeader) Serialize() string { var buf strings.Builder if f.DisplayName != "" { buf.WriteByte('"') buf.WriteString(f.DisplayName) buf.WriteString("\" ") } buf.WriteByte('<') buf.WriteString(f.URI) buf.WriteByte('>') if f.Tag != "" { buf.WriteString(";tag=") buf.WriteString(f.Tag) } for key, val := range f.Params { buf.WriteByte(';') buf.WriteString(key) if val != "" { buf.WriteByte('=') buf.WriteString(val) } } return buf.String() } // ExtractHostFromURI extracts the host part from a SIP URI // Example: "sip:user@host:5060;transport=udp" -> "host" func ExtractHostFromURI(uri string) string { // Remove scheme if idx := strings.Index(uri, ":"); idx >= 0 { uri = uri[idx+1:] } // Remove user part if idx := strings.Index(uri, "@"); idx >= 0 { uri = uri[idx+1:] } // Remove port and parameters if idx := strings.Index(uri, ":"); idx >= 0 { uri = uri[:idx] } if idx := strings.Index(uri, ";"); idx >= 0 { uri = uri[:idx] } return uri } // ExtractPortFromURI extracts the port from a SIP URI (default 5060) func ExtractPortFromURI(uri string) int { // Remove scheme and user part if idx := strings.Index(uri, "@"); idx >= 0 { uri = uri[idx+1:] } else if idx := strings.Index(uri, ":"); idx >= 0 { uri = uri[idx+1:] } // Extract port if colonIdx := strings.Index(uri, ":"); colonIdx >= 0 { portStr := uri[colonIdx+1:] if semiIdx := strings.Index(portStr, ";"); semiIdx >= 0 { portStr = portStr[:semiIdx] } port, err := strconv.Atoi(portStr) if err == nil { return port } } return 5060 // Default SIP port } // RewriteURIHost replaces the host in a SIP URI func RewriteURIHost(uri, newHost string, newPort int) string { // Pattern to match SIP URI re := regexp.MustCompile(`^(sips?:)([^@]+@)?([^:;>]+)(:\d+)?(.*)$`) matches := re.FindStringSubmatch(uri) if matches == nil { return uri } // Rebuild URI with new host result := matches[1] + matches[2] + newHost if newPort > 0 && newPort != 5060 { result += ":" + strconv.Itoa(newPort) } result += matches[5] return result } // GenerateBranch generates an RFC 3261 compliant branch parameter func GenerateBranch() string { // Must start with z9hG4bK for RFC 3261 compliance return "z9hG4bK" + generateRandomHex(16) } // generateRandomHex generates a random hex string of given length func generateRandomHex(length int) string { const chars = "0123456789abcdef" result := make([]byte, length) for i := range result { // Use uint64 modulo to avoid negative index from int conversion result[i] = chars[pseudoRand()%uint64(len(chars))] } return string(result) } // Simple pseudo-random number generator (not cryptographically secure) var prngState uint64 = 0xdeadbeef func pseudoRand() uint64 { prngState ^= prngState << 13 prngState ^= prngState >> 7 prngState ^= prngState << 17 return prngState } // InitPRNG seeds the pseudo-random number generator func InitPRNG(seed uint64) { prngState = seed if prngState == 0 { prngState = 0xdeadbeef } } // Note: IsPrivateIP is defined in geoip.go with proper CIDR parsing // NormalizeHeaderName converts header name to canonical form func NormalizeHeaderName(name string) string { // Map compact forms to full names compactToFull := map[string]string{ "i": "Call-ID", "m": "Contact", "e": "Content-Encoding", "l": "Content-Length", "c": "Content-Type", "f": "From", "s": "Subject", "k": "Supported", "t": "To", "v": "Via", } lower := strings.ToLower(name) if full, ok := compactToFull[lower]; ok { return full } // Title case for standard headers return strings.Title(lower) }