Add Origin checking

This commit is contained in:
Teal Bauer 2025-04-04 16:15:29 +02:00
parent 1e737ed44b
commit e462164321

View File

@ -11,12 +11,16 @@ import sys
import threading
import time
from threading import Lock
from typing import Dict
from typing import Dict, List
from urllib.parse import quote
from urllib.parse import urlparse
import requests
from mcp.server.fastmcp import FastMCP
# Allowed origins for CORS/CSRF protection
ALLOWED_ORIGINS = os.environ.get("GHIDRA_ALLOWED_ORIGINS", "http://localhost").split(",")
# Track active Ghidra instances (port -> info dict)
active_instances: Dict[int, dict] = {}
instances_lock = Lock()
@ -51,6 +55,23 @@ def get_instance_url(port: int) -> str:
return f"http://{ghidra_host}:{port}"
def validate_origin(headers: dict) -> bool:
"""Validate request origin against allowed origins"""
origin = headers.get("Origin")
if not origin:
return True # No origin header - allow (browser same-origin policy applies)
# Parse origin to get scheme+hostname
try:
parsed = urlparse(origin)
origin_base = f"{parsed.scheme}://{parsed.hostname}"
if parsed.port:
origin_base += f":{parsed.port}"
except:
return False
return origin_base in ALLOWED_ORIGINS
def safe_get(port: int, endpoint: str, params: dict = None) -> dict:
"""Perform a GET request to a specific Ghidra instance and return JSON response"""
if params is None:
@ -58,6 +79,15 @@ def safe_get(port: int, endpoint: str, params: dict = None) -> dict:
url = f"{get_instance_url(port)}/{endpoint}"
# Check origin if this is a state-changing request
if endpoint not in ["instances", "info"] and not validate_origin(params.get("headers", {})):
return {
"success": False,
"error": "Origin not allowed",
"status_code": 403,
"timestamp": int(time.time() * 1000)
}
try:
response = requests.get(
url,
@ -131,6 +161,14 @@ def safe_put(port: int, endpoint: str, data: dict) -> dict:
"""Perform a PUT request to a specific Ghidra instance with JSON payload"""
try:
url = f"{get_instance_url(port)}/{endpoint}"
# Always validate origin for PUT requests
if not validate_origin(data.get("headers", {})):
return {
"success": False,
"error": "Origin not allowed",
"status_code": 403
}
response = requests.put(
url,
json=data,
@ -184,6 +222,15 @@ def safe_post(port: int, endpoint: str, data: dict | str) -> dict:
try:
url = f"{get_instance_url(port)}/{endpoint}"
# Always validate origin for POST requests
headers = data.get("headers", {}) if isinstance(data, dict) else {}
if not validate_origin(headers):
return {
"success": False,
"error": "Origin not allowed",
"status_code": 403
}
if isinstance(data, dict):
response = requests.post(
url,