diff --git a/bridge_mcp_hydra.py b/bridge_mcp_hydra.py index f30e828..3a99e82 100644 --- a/bridge_mcp_hydra.py +++ b/bridge_mcp_hydra.py @@ -11,12 +11,16 @@ import sys import threading import time from threading import Lock -from typing import Dict +from typing import Dict, List from urllib.parse import quote +from urllib.parse import urlparse import requests from mcp.server.fastmcp import FastMCP +# Allowed origins for CORS/CSRF protection +ALLOWED_ORIGINS = os.environ.get("GHIDRA_ALLOWED_ORIGINS", "http://localhost").split(",") + # Track active Ghidra instances (port -> info dict) active_instances: Dict[int, dict] = {} instances_lock = Lock() @@ -51,12 +55,38 @@ def get_instance_url(port: int) -> str: return f"http://{ghidra_host}:{port}" +def validate_origin(headers: dict) -> bool: + """Validate request origin against allowed origins""" + origin = headers.get("Origin") + if not origin: + return True # No origin header - allow (browser same-origin policy applies) + + # Parse origin to get scheme+hostname + try: + parsed = urlparse(origin) + origin_base = f"{parsed.scheme}://{parsed.hostname}" + if parsed.port: + origin_base += f":{parsed.port}" + except: + return False + + return origin_base in ALLOWED_ORIGINS + def safe_get(port: int, endpoint: str, params: dict = None) -> dict: """Perform a GET request to a specific Ghidra instance and return JSON response""" if params is None: params = {} url = f"{get_instance_url(port)}/{endpoint}" + + # Check origin if this is a state-changing request + if endpoint not in ["instances", "info"] and not validate_origin(params.get("headers", {})): + return { + "success": False, + "error": "Origin not allowed", + "status_code": 403, + "timestamp": int(time.time() * 1000) + } try: response = requests.get( @@ -131,6 +161,14 @@ def safe_put(port: int, endpoint: str, data: dict) -> dict: """Perform a PUT request to a specific Ghidra instance with JSON payload""" try: url = f"{get_instance_url(port)}/{endpoint}" + + # Always validate origin for PUT requests + if not validate_origin(data.get("headers", {})): + return { + "success": False, + "error": "Origin not allowed", + "status_code": 403 + } response = requests.put( url, json=data, @@ -183,6 +221,15 @@ def safe_post(port: int, endpoint: str, data: dict | str) -> dict: """Perform a POST request to a specific Ghidra instance with JSON payload""" try: url = f"{get_instance_url(port)}/{endpoint}" + + # Always validate origin for POST requests + headers = data.get("headers", {}) if isinstance(data, dict) else {} + if not validate_origin(headers): + return { + "success": False, + "error": "Origin not allowed", + "status_code": 403 + } if isinstance(data, dict): response = requests.post(