mcghidra/bridge_mcp_hydra.py
Teal Bauer a5c600b07f fix: Resolve MCP bridge test failures
Standardizes communication between the Python bridge and Java plugin,
resolves test logic errors, and improves error handling to ensure
MCP bridge tests pass reliably.

Key changes:
- Standardized HTTP methods: Use GET for read operations and POST for all modification operations across the bridge and plugin.
- Fixed JSON parsing in Java plugin using Gson and added missing imports.
- Corrected error handling in Java plugin's `get_function` to return `success: false` when a function is not found.
- Updated Python bridge's `safe_get` to correctly propagate nested failure responses from the plugin.
- Fixed test client logic (`test_mcp_client.py`) to correctly extract function name/address from `list_functions` results.
- Added logging to `test_mcp_client.py` for easier debugging of mutating operations.
2025-04-07 14:31:46 +02:00

788 lines
29 KiB
Python

# /// script
# requires-python = ">=3.11"
# dependencies = [
# "mcp==1.6.0",
# "requests==2.32.3",
# ]
# ///
import os
import signal
import sys
import threading
import time
from threading import Lock
from typing import Dict, List
from urllib.parse import quote
from urllib.parse import urlparse
import requests
from mcp.server.fastmcp import FastMCP
# Allowed origins for CORS/CSRF protection
ALLOWED_ORIGINS = os.environ.get("GHIDRA_ALLOWED_ORIGINS", "http://localhost").split(",")
# Track active Ghidra instances (port -> info dict)
active_instances: Dict[int, dict] = {}
instances_lock = Lock()
DEFAULT_GHIDRA_PORT = 8192
DEFAULT_GHIDRA_HOST = "localhost"
# Port ranges for scanning
QUICK_DISCOVERY_RANGE = range(8192, 8202) # Limited range for interactive/triggered discovery (10 ports)
FULL_DISCOVERY_RANGE = range(8192, 8212) # Wider range for background discovery (20 ports)
# Server-level instructions text passed to the FastMCP constructor below.
instructions = """
GhydraMCP allows interacting with multiple Ghidra SRE instances. Ghidra SRE is a tool for reverse engineering and analyzing binaries, e.g. malware.
First, run `discover_instances` to find open Ghidra instances. List tools to see what GhydraMCP can do.
"""
# The MCP server object; every @mcp.tool() function in this file registers on it.
mcp = FastMCP("GhydraMCP", instructions=instructions)
# Host used when building instance URLs; override with the GHIDRA_HYDRA_HOST environment variable.
ghidra_host = os.environ.get("GHIDRA_HYDRA_HOST", DEFAULT_GHIDRA_HOST)
# print(f"Using Ghidra host: {ghidra_host}")
def get_instance_url(port: int) -> str:
    """Get URL for a Ghidra instance by port.

    Looks the port up in ``active_instances``. If it is not registered and
    lies in 8192..65535, an auto-registration is attempted first. When no
    registered URL is available, a URL built from ``ghidra_host`` and the
    port is returned.

    Args:
        port: Port of the Ghidra instance.

    Returns:
        Base URL of the instance, without a trailing slash.
    """
    with instances_lock:
        info = active_instances.get(port)
    if info is not None:
        return info["url"]

    # Auto-register if not found but port is valid. This must happen with the
    # lock released: register_instance() acquires instances_lock itself, and
    # threading.Lock is not reentrant, so calling it while holding the lock
    # would deadlock as soon as the instance answers.
    if 8192 <= port <= 65535:
        register_instance(port)
        with instances_lock:
            info = active_instances.get(port)
        if info is not None:
            return info["url"]

    return f"http://{ghidra_host}:{port}"
def validate_origin(headers: dict) -> bool:
    """Validate request origin against allowed origins.

    Args:
        headers: Mapping of header names to values; only "Origin" is read.

    Returns:
        True when there is no (or an empty) Origin header, or when the
        origin's scheme://host[:port] is listed in ALLOWED_ORIGINS.
        False when the origin cannot be parsed or is not listed.
    """
    origin = headers.get("Origin")
    if not origin:
        return True  # No origin header - allow (browser same-origin policy applies)

    # Parse origin to get scheme+hostname (+port when given)
    try:
        parsed = urlparse(origin)
        origin_base = f"{parsed.scheme}://{parsed.hostname}"
        # Accessing .port raises ValueError for a non-numeric or out-of-range port
        if parsed.port:
            origin_base += f":{parsed.port}"
    except (ValueError, TypeError, AttributeError):
        # Malformed or non-string origin: reject, but do not swallow
        # unrelated exceptions such as KeyboardInterrupt (the old bare except did).
        return False

    return origin_base in ALLOWED_ORIGINS
def safe_get(port: int, endpoint: str, params: dict = None) -> dict:
    """Issue a GET against one Ghidra instance and return a response dict.

    Args:
        port: Port of the target instance.
        endpoint: Path appended to the instance URL.
        params: Optional query parameters; a "headers" entry, if present, is
            what the origin check inspects.

    Returns:
        The instance's JSON when it is already structured (has "result" or
        "success"), otherwise the JSON wrapped as {"success": True, "data": ...}.
        Failures are reported as {"success": False, "error": ...} dicts.
    """
    query = {} if params is None else params
    target = f"{get_instance_url(port)}/{endpoint}"

    # "instances" and "info" skip the origin check; everything else is validated
    origin_exempt = endpoint in ["instances", "info"]
    if not origin_exempt and not validate_origin(query.get("headers", {})):
        return {
            "success": False,
            "error": "Origin not allowed",
            "status_code": 403,
            "timestamp": int(time.time() * 1000)
        }

    try:
        reply = requests.get(
            target,
            params=query,
            headers={'Accept': 'application/json'},
            timeout=5
        )

        if not reply.ok:
            # A 404 from a secondary instance is retried on the default instance
            if reply.status_code == 404 and port != DEFAULT_GHIDRA_PORT:
                return safe_get(DEFAULT_GHIDRA_PORT, endpoint, query)
            try:
                failure = reply.json()
                message = failure.get("error", f"HTTP {reply.status_code}")
            except ValueError:
                message = reply.text.strip()
            return {
                "success": False,
                "error": message,
                "status_code": reply.status_code,
                "timestamp": int(time.time() * 1000)
            }

        try:
            payload = reply.json()
        except ValueError:
            # Body was not JSON: report it together with the raw text
            return {
                "success": False,
                "error": "Invalid JSON response",
                "response": reply.text,
                "timestamp": int(time.time() * 1000)
            }

        if not isinstance(payload, dict):
            # Bare JSON values (lists, strings, ...) get the standard wrapper
            return {
                "success": True,
                "data": payload,
                "timestamp": int(time.time() * 1000)
            }

        if 'result' in payload:
            nested = payload.get("data")
            if isinstance(nested, dict) and nested.get("success") is False:
                # Surface a failure that the plugin reported inside "data"
                return {
                    "success": False,
                    "error": nested.get("error", "Nested operation failed"),
                    "status_code": reply.status_code,
                    "timestamp": int(time.time() * 1000)
                }
            return payload

        if 'success' not in payload:
            # Unstructured dict: wrap it like any other raw payload
            return {
                "success": True,
                "data": payload,
                "timestamp": int(time.time() * 1000)
            }

        return payload
    except requests.exceptions.ConnectionError:
        # Instance may be down - a secondary instance falls back to the default one
        if port != DEFAULT_GHIDRA_PORT:
            return safe_get(DEFAULT_GHIDRA_PORT, endpoint, query)
        return {
            "success": False,
            "error": "Failed to connect to Ghidra instance",
            "status_code": 503,
            "timestamp": int(time.time() * 1000)
        }
    except Exception as exc:
        return {
            "success": False,
            "error": str(exc),
            "exception": exc.__class__.__name__,
            "timestamp": int(time.time() * 1000)
        }
def safe_put(port: int, endpoint: str, data: dict) -> dict:
    """Send a JSON PUT to one Ghidra instance and return its response.

    Args:
        port: Port of the target instance.
        endpoint: Path appended to the instance URL.
        data: JSON payload; a "headers" entry, if present, is what the
            origin check inspects.

    Returns:
        The parsed JSON body on success (or {"success": True, "result": text}
        when the body is not JSON); a {"success": False, ...} dict on failure.
    """
    try:
        target = f"{get_instance_url(port)}/{endpoint}"

        # PUT requests are always origin-checked
        origin_ok = validate_origin(data.get("headers", {}))
        if not origin_ok:
            return {
                "success": False,
                "error": "Origin not allowed",
                "status_code": 403
            }

        reply = requests.put(
            target,
            json=data,
            headers={'Content-Type': 'application/json'},
            timeout=5
        )

        if reply.ok:
            try:
                return reply.json()
            except ValueError:
                return {
                    "success": True,
                    "result": reply.text.strip()
                }

        # A 404 from a secondary instance is retried on the default instance
        if port != DEFAULT_GHIDRA_PORT and reply.status_code == 404:
            return safe_put(DEFAULT_GHIDRA_PORT, endpoint, data)

        try:
            message = reply.json().get("error", f"HTTP {reply.status_code}")
        except ValueError:
            message = reply.text.strip()
        return {
            "success": False,
            "error": message,
            "status_code": reply.status_code
        }
    except requests.exceptions.ConnectionError:
        if port != DEFAULT_GHIDRA_PORT:
            return safe_put(DEFAULT_GHIDRA_PORT, endpoint, data)
        return {
            "success": False,
            "error": "Failed to connect to Ghidra instance",
            "status_code": 503
        }
    except Exception as exc:
        return {
            "success": False,
            "error": str(exc),
            "exception": exc.__class__.__name__
        }
def safe_post(port: int, endpoint: str, data: dict | str) -> dict:
    """Send a POST to one Ghidra instance and return its response.

    A dict payload is sent as JSON; any other payload is sent as text/plain.

    Args:
        port: Port of the target instance.
        endpoint: Path appended to the instance URL.
        data: Request payload; for a dict, a "headers" entry is what the
            origin check inspects.

    Returns:
        The parsed JSON body on success (or {"success": True, "result": text}
        when the body is not JSON); a {"success": False, ...} dict on failure.
    """
    try:
        target = f"{get_instance_url(port)}/{endpoint}"
        is_json_payload = isinstance(data, dict)

        # POST requests are always origin-checked; string payloads carry no headers
        origin_headers = data.get("headers", {}) if is_json_payload else {}
        if not validate_origin(origin_headers):
            return {
                "success": False,
                "error": "Origin not allowed",
                "status_code": 403
            }

        if is_json_payload:
            reply = requests.post(
                target,
                json=data,
                headers={'Content-Type': 'application/json'},
                timeout=5
            )
        else:
            reply = requests.post(
                target,
                data=data,
                headers={'Content-Type': 'text/plain'},
                timeout=5
            )

        if reply.ok:
            try:
                return reply.json()
            except ValueError:
                return {
                    "success": True,
                    "result": reply.text.strip()
                }

        # Unlike safe_get/safe_put, an HTTP error here is not retried on the default instance
        try:
            message = reply.json().get("error", f"HTTP {reply.status_code}")
        except ValueError:
            message = reply.text.strip()
        return {
            "success": False,
            "error": message,
            "status_code": reply.status_code
        }
    except requests.exceptions.ConnectionError:
        if port != DEFAULT_GHIDRA_PORT:
            return safe_post(DEFAULT_GHIDRA_PORT, endpoint, data)
        return {
            "success": False,
            "error": "Failed to connect to Ghidra instance",
            "status_code": 503
        }
    except Exception as exc:
        return {
            "success": False,
            "error": str(exc),
            "exception": exc.__class__.__name__
        }
# Instance management tools
@mcp.tool()
def list_instances() -> dict:
    """Return every registered Ghidra instance.

    Returns:
        {"instances": [...]} with one entry per registered port, each holding
        "port", "url", "project" and "file" ("" when not known).
    """
    summaries = []
    with instances_lock:
        for port, info in active_instances.items():
            summaries.append({
                "port": port,
                "url": info["url"],
                "project": info.get("project", ""),
                "file": info.get("file", "")
            })
    return {"instances": summaries}
def _collect_project_info(url: str) -> dict:
    """Best-effort lookup of project/file metadata for the instance at *url*.

    Queries the root endpoint first and falls back to /info when neither a
    project nor a file name was found. Never raises: problems are reported
    on stderr and whatever was gathered so far is returned.
    """
    project_info = {"url": url}
    try:
        # Try the root endpoint first (short timeout)
        root_response = requests.get(f"{url}/", timeout=1.5)
        if root_response.ok:
            try:
                root_data = root_response.json()
                # Extract basic information from root
                if "project" in root_data and root_data["project"]:
                    project_info["project"] = root_data["project"]
                if "file" in root_data and root_data["file"]:
                    project_info["file"] = root_data["file"]
            except Exception as e:
                print(f"Error parsing root info: {e}", file=sys.stderr)

        # If we don't have project info yet, try the /info endpoint as a fallback
        if not project_info.get("project") and not project_info.get("file"):
            try:
                info_response = requests.get(f"{url}/info", timeout=2)
                if info_response.ok:
                    try:
                        info_data = info_response.json()
                        if "project" in info_data and info_data["project"]:
                            project_info["project"] = info_data["project"]
                        # Here "file" is expected to be an object with details
                        file_info = info_data.get("file", {})
                        if isinstance(file_info, dict) and file_info.get("name"):
                            project_info["file"] = file_info.get("name", "")
                            project_info["path"] = file_info.get("path", "")
                            project_info["architecture"] = file_info.get("architecture", "")
                            project_info["endian"] = file_info.get("endian", "")
                        print(f"Info data parsed: {project_info}", file=sys.stderr)
                    except Exception as e:
                        print(f"Error parsing info endpoint: {e}", file=sys.stderr)
            except Exception as e:
                print(f"Error connecting to info endpoint: {e}", file=sys.stderr)
    except Exception as e:
        # Non-critical: registration continues without project info, but the
        # failure is logged instead of being silently discarded.
        print(f"Error fetching project info: {e}", file=sys.stderr)
    return project_info


@mcp.tool()
def register_instance(port: int, url: str | None = None) -> str:
    """Register a new Ghidra instance

    Args:
        port: Port the instance listens on; used as the registry key
        url: Base URL of the instance (defaults to http://<ghidra_host>:<port>)

    Returns:
        Confirmation message, or a message starting with "Error:" if the
        instance could not be reached
    """
    if url is None:
        url = f"http://{ghidra_host}:{port}"

    # Verify instance is reachable before registering
    try:
        response = requests.get(f"{url}/instances", timeout=2)
    except Exception as e:
        # Tool boundary: report any failure as a message rather than raising
        return f"Error: Could not connect to instance at {url}: {str(e)}"
    if not response.ok:
        return f"Error: Instance at {url} is not responding properly"

    project_info = _collect_project_info(url)

    with instances_lock:
        active_instances[port] = project_info

    return f"Registered instance on port {port} at {url}"
@mcp.tool()
def unregister_instance(port: int) -> str:
    """Remove a Ghidra instance from the registry.

    Args:
        port: Port of the instance to forget

    Returns:
        Message stating whether an instance was removed or none was found
    """
    with instances_lock:
        was_registered = port in active_instances
        if was_registered:
            active_instances.pop(port)

    if not was_registered:
        return f"No instance found on port {port}"
    return f"Unregistered instance on port {port}"
@mcp.tool()
def discover_instances(host: str | None = None) -> dict:
    """Auto-discover Ghidra instances by scanning ports (quick discovery with limited range)

    Args:
        host: Optional host to scan (defaults to configured ghidra_host)

    Returns:
        Dict with "found" (number of newly registered instances) and
        "instances" (one entry per instance found)
    """
    # The annotation is "str | None" so that an explicit null from a client is
    # accepted the same way as omitting the argument.
    return _discover_instances(QUICK_DISCOVERY_RANGE, host=host, timeout=0.5)
def _discover_instances(port_range, host=None, timeout=0.5) -> dict:
    """Scan a range of ports for Ghidra instances and register the ones that answer.

    Args:
        port_range: Iterable of ports to probe; already-registered ports are skipped.
        host: Host to scan; falls back to ghidra_host when None.
        timeout: Per-port timeout in seconds for the probe request.

    Returns:
        {"found": <count>, "instances": [{"port", "url", "result"}, ...]}
    """
    target_host = ghidra_host if host is None else host
    discovered = []

    for port in port_range:
        if port in active_instances:
            continue

        base_url = f"http://{target_host}:{port}"
        try:
            # Short timeout keeps the scan quick when nothing listens on the port
            probe = requests.get(f"{base_url}/instances", timeout=timeout)
            if probe.ok:
                outcome = register_instance(port, base_url)
                discovered.append({"port": port, "url": base_url, "result": outcome})
        except requests.exceptions.RequestException:
            # Nothing usable on this port; move on to the next one
            continue

    return {"found": len(discovered), "instances": discovered}
@mcp.tool()
def list_functions(port: int = DEFAULT_GHIDRA_PORT, offset: int = 0, limit: int = 100) -> dict:
    """List all functions in the current program

    Args:
        port: Ghidra instance port (default: 8192)
        offset: Pagination offset (default: 0)
        limit: Maximum number of functions to return (default: 100)

    Returns:
        Response dict from the instance's "functions" endpoint; on failure
        it has "success": False and an "error" message
    """
    return safe_get(port, "functions", {"offset": offset, "limit": limit})
@mcp.tool()
def list_classes(port: int = DEFAULT_GHIDRA_PORT, offset: int = 0, limit: int = 100) -> dict:
    """List all classes with pagination

    Args:
        port: Ghidra instance port (default: 8192)
        offset: Pagination offset (default: 0)
        limit: Maximum number of classes to return (default: 100)

    Returns:
        Response dict from the instance's "classes" endpoint; on failure
        it has "success": False and an "error" message
    """
    return safe_get(port, "classes", {"offset": offset, "limit": limit})
@mcp.tool()
def get_function(port: int = DEFAULT_GHIDRA_PORT, name: str = "") -> dict:
    """Get decompiled code for a specific function

    Args:
        port: Ghidra instance port (default: 8192)
        name: Name of the function to look up

    Returns:
        Response dict from the instance's "functions/<name>" endpoint; on
        failure it has "success": False and an "error" message
    """
    # safe="" also escapes "/" so a name such as "operator/" stays a single
    # path segment instead of changing the endpoint path.
    return safe_get(port, f"functions/{quote(name, safe='')}", {})
@mcp.tool()
def update_function(port: int = DEFAULT_GHIDRA_PORT, name: str = "", new_name: str = "") -> dict:
    """Rename a function (Modify -> POST)

    Args:
        port: Ghidra instance port (default: 8192)
        name: Current name of the function
        new_name: New name for the function

    Returns:
        Response dict from the instance; on failure it has "success": False
        and an "error" message
    """
    # safe="" also escapes "/" so the name stays a single path segment
    return safe_post(port, f"functions/{quote(name, safe='')}", {"newName": new_name})
@mcp.tool()
def update_data(port: int = DEFAULT_GHIDRA_PORT, address: str = "", new_name: str = "") -> dict:
    """Rename data at specified address (Modify -> POST)

    Args:
        port: Ghidra instance port (default: 8192)
        address: Address of the data item to rename
        new_name: New name for the data item

    Returns:
        Response dict from the instance's "data" endpoint; on failure it has
        "success": False and an "error" message
    """
    return safe_post(port, "data", {"address": address, "newName": new_name})
@mcp.tool()
def list_segments(port: int = DEFAULT_GHIDRA_PORT, offset: int = 0, limit: int = 100) -> dict:
    """List all memory segments in the current program with pagination

    Args:
        port: Ghidra instance port (default: 8192)
        offset: Pagination offset (default: 0)
        limit: Maximum number of segments to return (default: 100)

    Returns:
        Response dict from the instance's "segments" endpoint; on failure
        it has "success": False and an "error" message
    """
    return safe_get(port, "segments", {"offset": offset, "limit": limit})
@mcp.tool()
def list_imports(port: int = DEFAULT_GHIDRA_PORT, offset: int = 0, limit: int = 100) -> dict:
    """List all imported symbols with pagination

    Args:
        port: Ghidra instance port (default: 8192)
        offset: Pagination offset (default: 0)
        limit: Maximum number of imports to return (default: 100)

    Returns:
        Response dict from the instance's "symbols/imports" endpoint; on
        failure it has "success": False and an "error" message
    """
    return safe_get(port, "symbols/imports", {"offset": offset, "limit": limit})
@mcp.tool()
def list_exports(port: int = DEFAULT_GHIDRA_PORT, offset: int = 0, limit: int = 100) -> dict:
    """List all exported symbols with pagination

    Args:
        port: Ghidra instance port (default: 8192)
        offset: Pagination offset (default: 0)
        limit: Maximum number of exports to return (default: 100)

    Returns:
        Response dict from the instance's "symbols/exports" endpoint; on
        failure it has "success": False and an "error" message
    """
    return safe_get(port, "symbols/exports", {"offset": offset, "limit": limit})
@mcp.tool()
def list_namespaces(port: int = DEFAULT_GHIDRA_PORT, offset: int = 0, limit: int = 100) -> dict:
    """List all namespaces in the current program with pagination

    Args:
        port: Ghidra instance port (default: 8192)
        offset: Pagination offset (default: 0)
        limit: Maximum number of namespaces to return (default: 100)

    Returns:
        Response dict from the instance's "namespaces" endpoint; on failure
        it has "success": False and an "error" message
    """
    return safe_get(port, "namespaces", {"offset": offset, "limit": limit})
@mcp.tool()
def list_data_items(port: int = DEFAULT_GHIDRA_PORT, offset: int = 0, limit: int = 100) -> dict:
    """List all defined data items with pagination

    Args:
        port: Ghidra instance port (default: 8192)
        offset: Pagination offset (default: 0)
        limit: Maximum number of data items to return (default: 100)

    Returns:
        Response dict from the instance's "data" endpoint; on failure it has
        "success": False and an "error" message
    """
    return safe_get(port, "data", {"offset": offset, "limit": limit})
@mcp.tool()
def search_functions_by_name(port: int = DEFAULT_GHIDRA_PORT, query: str = "", offset: int = 0, limit: int = 100) -> dict | list:
    """Search for functions by name with pagination

    Args:
        port: Ghidra instance port (default: 8192)
        query: Search string to match against function names
        offset: Pagination offset (default: 0)
        limit: Maximum number of functions to return (default: 100)

    Returns:
        Response dict from the instance's "functions" endpoint, or a
        one-element list holding an error message if query is empty
    """
    if not query:
        # Kept as a list for compatibility with existing callers
        return ["Error: query string is required"]
    return safe_get(port, "functions", {"query": query, "offset": offset, "limit": limit})
def _response_to_text(response) -> str:
    """Render a safe_get() response as text for the string-returning tools.

    safe_get() returns a dict, so joining it directly would join its keys
    ("success", "timestamp", ...) rather than the content. Failures become
    "Error: <message>"; otherwise the "result" (or else "data") payload is
    rendered: strings as-is, lists one item per line, dicts as "key: value"
    lines.
    """
    payload = response
    if isinstance(response, dict):
        if response.get("success") is False:
            return f"Error: {response.get('error', 'Unknown error')}"
        for key in ("result", "data"):
            if key in response:
                payload = response[key]
                break

    if isinstance(payload, str):
        return payload
    if isinstance(payload, (list, tuple)):
        return "\n".join(str(item) for item in payload)
    if isinstance(payload, dict):
        return "\n".join(f"{key}: {value}" for key, value in payload.items())
    return str(payload)


@mcp.tool()
def get_function_by_address(port: int = DEFAULT_GHIDRA_PORT, address: str = "") -> str:
    """Get function details by its memory address

    Args:
        port: Ghidra instance port (default: 8192)
        address: Memory address of the function (hex string)

    Returns:
        Text rendering of the instance's response, or "Error: <message>" on failure
    """
    return _response_to_text(safe_get(port, "get_function_by_address", {"address": address}))


@mcp.tool()
def get_current_address(port: int = DEFAULT_GHIDRA_PORT) -> str:
    """Get the address currently selected in Ghidra's UI

    Args:
        port: Ghidra instance port (default: 8192)

    Returns:
        Text rendering of the instance's response, or "Error: <message>" on failure
    """
    return _response_to_text(safe_get(port, "get_current_address"))


@mcp.tool()
def get_current_function(port: int = DEFAULT_GHIDRA_PORT) -> str:
    """Get the function currently selected in Ghidra's UI

    Args:
        port: Ghidra instance port (default: 8192)

    Returns:
        Text rendering of the instance's response, or "Error: <message>" on failure
    """
    return _response_to_text(safe_get(port, "get_current_function"))


@mcp.tool()
def decompile_function_by_address(port: int = DEFAULT_GHIDRA_PORT, address: str = "") -> str:
    """Decompile a function at a specific memory address

    Args:
        port: Ghidra instance port (default: 8192)
        address: Memory address of the function (hex string)

    Returns:
        Text rendering of the instance's response, or "Error: <message>" on failure
    """
    return _response_to_text(safe_get(port, "decompile_function", {"address": address}))
@mcp.tool()
def disassemble_function(port: int = DEFAULT_GHIDRA_PORT, address: str = "") -> dict:
    """Get disassembly for a function at a specific address

    Args:
        port: Ghidra instance port (default: 8192)
        address: Memory address of the function (hex string)

    Returns:
        Response dict from the instance's "disassemble_function" endpoint;
        on failure it has "success": False and an "error" message
    """
    return safe_get(port, "disassemble_function", {"address": address})
@mcp.tool()
def set_decompiler_comment(port: int = DEFAULT_GHIDRA_PORT, address: str = "", comment: str = "") -> dict:
    """Add/edit a comment in the decompiler view at a specific address

    Args:
        port: Ghidra instance port (default: 8192)
        address: Memory address to place comment (hex string)
        comment: Text of the comment to add

    Returns:
        Response dict from the instance; on failure it has "success": False
        and an "error" message
    """
    return safe_post(port, "set_decompiler_comment", {"address": address, "comment": comment})
@mcp.tool()
def set_disassembly_comment(port: int = DEFAULT_GHIDRA_PORT, address: str = "", comment: str = "") -> dict:
    """Add/edit a comment in the disassembly view at a specific address

    Args:
        port: Ghidra instance port (default: 8192)
        address: Memory address to place comment (hex string)
        comment: Text of the comment to add

    Returns:
        Response dict from the instance; on failure it has "success": False
        and an "error" message
    """
    return safe_post(port, "set_disassembly_comment", {"address": address, "comment": comment})
@mcp.tool()
def rename_local_variable(port: int = DEFAULT_GHIDRA_PORT, function_address: str = "", old_name: str = "", new_name: str = "") -> dict:
    """Rename a local variable within a function

    Args:
        port: Ghidra instance port (default: 8192)
        function_address: Memory address of the function (hex string)
        old_name: Current name of the variable
        new_name: New name for the variable

    Returns:
        Response dict from the instance; on failure it has "success": False
        and an "error" message
    """
    payload = {
        "functionAddress": function_address,
        "oldName": old_name,
        "newName": new_name,
    }
    return safe_post(port, "rename_local_variable", payload)
@mcp.tool()
def rename_function_by_address(port: int = DEFAULT_GHIDRA_PORT, function_address: str = "", new_name: str = "") -> dict:
    """Rename a function at a specific memory address

    Args:
        port: Ghidra instance port (default: 8192)
        function_address: Memory address of the function (hex string)
        new_name: New name for the function

    Returns:
        Response dict from the instance; on failure it has "success": False
        and an "error" message
    """
    payload = {"functionAddress": function_address, "newName": new_name}
    return safe_post(port, "rename_function_by_address", payload)
@mcp.tool()
def set_function_prototype(port: int = DEFAULT_GHIDRA_PORT, function_address: str = "", prototype: str = "") -> dict:
    """Update a function's signature/prototype

    Args:
        port: Ghidra instance port (default: 8192)
        function_address: Memory address of the function (hex string)
        prototype: New function prototype string (e.g. "int func(int param1)")

    Returns:
        Response dict from the instance; on failure it has "success": False
        and an "error" message
    """
    payload = {"functionAddress": function_address, "prototype": prototype}
    return safe_post(port, "set_function_prototype", payload)
@mcp.tool()
def set_local_variable_type(port: int = DEFAULT_GHIDRA_PORT, function_address: str = "", variable_name: str = "", new_type: str = "") -> dict:
    """Change the data type of a local variable in a function

    Args:
        port: Ghidra instance port (default: 8192)
        function_address: Memory address of the function (hex string)
        variable_name: Name of the variable to modify
        new_type: New data type for the variable (e.g. "int", "char*")

    Returns:
        Response dict from the instance; on failure it has "success": False
        and an "error" message
    """
    payload = {
        "functionAddress": function_address,
        "variableName": variable_name,
        "newType": new_type,
    }
    return safe_post(port, "set_local_variable_type", payload)
@mcp.tool()
def list_variables(port: int = DEFAULT_GHIDRA_PORT, offset: int = 0, limit: int = 100, search: str = "") -> dict:
    """List global variables with optional search

    Args:
        port: Ghidra instance port (default: 8192)
        offset: Pagination offset (default: 0)
        limit: Maximum number of variables to return (default: 100)
        search: Optional search string; only sent when non-empty

    Returns:
        Response dict from the instance's "variables" endpoint; on failure
        it has "success": False and an "error" message
    """
    params = {"offset": offset, "limit": limit}
    if search:
        params["search"] = search
    return safe_get(port, "variables", params)
@mcp.tool()
def list_function_variables(port: int = DEFAULT_GHIDRA_PORT, function: str = "") -> dict | str:
    """List variables in a specific function

    Args:
        port: Ghidra instance port (default: 8192)
        function: Name of the function whose variables are listed

    Returns:
        Response dict from the instance, or an "Error: ..." string when
        no function name is given
    """
    if not function:
        return "Error: function name is required"

    # safe="" also escapes "/" so the name stays a single path segment
    encoded_name = quote(function, safe="")
    return safe_get(port, f"functions/{encoded_name}/variables", {})
@mcp.tool()
def rename_variable(port: int = DEFAULT_GHIDRA_PORT, function: str = "", name: str = "", new_name: str = "") -> dict | str:
    """Rename a variable in a function

    Args:
        port: Ghidra instance port (default: 8192)
        function: Name of the function that owns the variable
        name: Current name of the variable
        new_name: New name for the variable

    Returns:
        Response dict from the instance, or an "Error: ..." string when a
        required argument is empty
    """
    if not function or not name or not new_name:
        return "Error: function, name, and new_name parameters are required"

    # safe="" also escapes "/" so each name stays a single path segment
    encoded_function = quote(function, safe="")
    encoded_var = quote(name, safe="")
    return safe_post(port, f"functions/{encoded_function}/variables/{encoded_var}", {"newName": new_name})
@mcp.tool()
def retype_variable(port: int = DEFAULT_GHIDRA_PORT, function: str = "", name: str = "", data_type: str = "") -> dict | str:
    """Change the data type of a variable in a function

    Args:
        port: Ghidra instance port (default: 8192)
        function: Name of the function that owns the variable
        name: Name of the variable to retype
        data_type: New data type for the variable

    Returns:
        Response dict from the instance, or an "Error: ..." string when a
        required argument is empty
    """
    if not function or not name or not data_type:
        return "Error: function, name, and data_type parameters are required"

    # safe="" also escapes "/" so each name stays a single path segment
    encoded_function = quote(function, safe="")
    encoded_var = quote(name, safe="")
    return safe_post(port, f"functions/{encoded_function}/variables/{encoded_var}", {"dataType": data_type})
def handle_sigint(signum, frame):
    """Signal handler that terminates the process immediately with status 0.

    Both arguments are supplied by the signal machinery and are not used.
    """
    exit_status = 0
    os._exit(exit_status)
def periodic_discovery():
    """Periodically discover new instances and drop unreachable ones.

    Runs forever, one pass every 30 seconds: scans FULL_DISCOVERY_RANGE for
    new instances, then probes every registered instance and unregisters
    those that no longer answer. All diagnostics go to stderr because stdout
    carries the MCP stdio protocol stream.
    """
    while True:
        try:
            # Use the full discovery range
            _discover_instances(FULL_DISCOVERY_RANGE, timeout=0.5)

            # Snapshot under the lock, then probe without holding it, so slow
            # or dead instances do not block tool calls that need the lock.
            with instances_lock:
                snapshot = [(port, info["url"]) for port, info in active_instances.items()]

            unreachable = []
            for port, url in snapshot:
                try:
                    response = requests.get(f"{url}/instances", timeout=1)
                    if not response.ok:
                        unreachable.append((port, url))
                except requests.exceptions.RequestException:
                    unreachable.append((port, url))

            # Remove any instances that are down, unless the entry was
            # re-registered with a different URL while we were probing.
            with instances_lock:
                for port, url in unreachable:
                    current = active_instances.get(port)
                    if current is not None and current.get("url") == url:
                        del active_instances[port]
                        print(f"Removed unreachable instance on port {port}", file=sys.stderr)
        except Exception as e:
            # Keep the background thread alive whatever goes wrong in one pass
            print(f"Error in periodic discovery: {e}", file=sys.stderr)

        # Sleep for 30 seconds before next scan
        time.sleep(30)
def main() -> None:
    """Start the bridge: register known instances, launch background discovery, serve MCP over stdio."""
    # Register the default instance, then look for others in the quick range
    default_url = f"http://{ghidra_host}:{DEFAULT_GHIDRA_PORT}"
    register_instance(DEFAULT_GHIDRA_PORT, default_url)
    discover_instances()

    # Daemon thread, so it never keeps the process alive on its own
    background = threading.Thread(
        name="GhydraMCP-Discovery",
        target=periodic_discovery,
        daemon=True
    )
    background.start()

    signal.signal(signal.SIGINT, handle_sigint)
    mcp.run(transport="stdio")


if __name__ == "__main__":
    main()