finalize HATEOAS updates

This commit is contained in:
Teal Bauer 2025-04-14 11:25:22 +02:00
parent 4268d3e2c5
commit 9b19011b7d
6 changed files with 761 additions and 68 deletions

View File

@ -82,7 +82,13 @@ def _make_request(method: str, port: int, endpoint: str, params: dict = None,
headers: dict = None) -> dict: headers: dict = None) -> dict:
"""Internal helper to make HTTP requests and handle common errors.""" """Internal helper to make HTTP requests and handle common errors."""
url = f"{get_instance_url(port)}/{endpoint}" url = f"{get_instance_url(port)}/{endpoint}"
request_headers = {'Accept': 'application/json'}
# Set up headers according to HATEOAS API expected format
request_headers = {
'Accept': 'application/json',
'X-Request-ID': f"mcp-bridge-{int(time.time() * 1000)}"
}
if headers: if headers:
request_headers.update(headers) request_headers.update(headers)
@ -93,7 +99,10 @@ def _make_request(method: str, port: int, endpoint: str, params: dict = None,
if not validate_origin(check_headers): if not validate_origin(check_headers):
return { return {
"success": False, "success": False,
"error": "Origin not allowed", "error": {
"code": "ORIGIN_NOT_ALLOWED",
"message": "Origin not allowed for state-changing request"
},
"status_code": 403, "status_code": 403,
"timestamp": int(time.time() * 1000) "timestamp": int(time.time() * 1000)
} }
@ -115,15 +124,32 @@ def _make_request(method: str, port: int, endpoint: str, params: dict = None,
try: try:
parsed_json = response.json() parsed_json = response.json()
# Add timestamp if not present # Add timestamp if not present
if isinstance(parsed_json, dict) and "timestamp" not in parsed_json: if isinstance(parsed_json, dict) and "timestamp" not in parsed_json:
parsed_json["timestamp"] = int(time.time() * 1000) parsed_json["timestamp"] = int(time.time() * 1000)
# Check for HATEOAS compliant error response format and reformat if needed
if not response.ok and isinstance(parsed_json, dict) and "success" in parsed_json and not parsed_json["success"]:
# Check if error is in the expected HATEOAS format
if "error" in parsed_json and not isinstance(parsed_json["error"], dict):
# Convert string error to the proper format
error_message = parsed_json["error"]
parsed_json["error"] = {
"code": f"HTTP_{response.status_code}",
"message": error_message
}
return parsed_json return parsed_json
except ValueError: except ValueError:
if response.ok: if response.ok:
return { return {
"success": False, "success": False,
"error": "Received non-JSON success response from Ghidra plugin", "error": {
"code": "NON_JSON_RESPONSE",
"message": "Received non-JSON success response from Ghidra plugin"
},
"status_code": response.status_code, "status_code": response.status_code,
"response_text": response.text[:500], "response_text": response.text[:500],
"timestamp": int(time.time() * 1000) "timestamp": int(time.time() * 1000)
@ -131,7 +157,10 @@ def _make_request(method: str, port: int, endpoint: str, params: dict = None,
else: else:
return { return {
"success": False, "success": False,
"error": f"HTTP {response.status_code} - Non-JSON error response", "error": {
"code": f"HTTP_{response.status_code}",
"message": f"Non-JSON error response: {response.text[:100]}..."
},
"status_code": response.status_code, "status_code": response.status_code,
"response_text": response.text[:500], "response_text": response.text[:500],
"timestamp": int(time.time() * 1000) "timestamp": int(time.time() * 1000)
@ -140,21 +169,30 @@ def _make_request(method: str, port: int, endpoint: str, params: dict = None,
except requests.exceptions.Timeout: except requests.exceptions.Timeout:
return { return {
"success": False, "success": False,
"error": "Request timed out", "error": {
"code": "REQUEST_TIMEOUT",
"message": "Request timed out"
},
"status_code": 408, "status_code": 408,
"timestamp": int(time.time() * 1000) "timestamp": int(time.time() * 1000)
} }
except requests.exceptions.ConnectionError: except requests.exceptions.ConnectionError:
return { return {
"success": False, "success": False,
"error": f"Failed to connect to Ghidra instance at {url}", "error": {
"code": "CONNECTION_ERROR",
"message": f"Failed to connect to Ghidra instance at {url}"
},
"status_code": 503, "status_code": 503,
"timestamp": int(time.time() * 1000) "timestamp": int(time.time() * 1000)
} }
except Exception as e: except Exception as e:
return { return {
"success": False, "success": False,
"error": f"An unexpected error occurred: {str(e)}", "error": {
"code": "UNEXPECTED_ERROR",
"message": f"An unexpected error occurred: {str(e)}"
},
"exception": e.__class__.__name__, "exception": e.__class__.__name__,
"timestamp": int(time.time() * 1000) "timestamp": int(time.time() * 1000)
} }
@ -211,6 +249,12 @@ def simplify_response(response: dict) -> dict:
# Make a copy to avoid modifying the original # Make a copy to avoid modifying the original
result = response.copy() result = response.copy()
# Store API response metadata
api_metadata = {}
for key in ["id", "instance", "timestamp", "size", "offset", "limit"]:
if key in result:
api_metadata[key] = result.get(key)
# Simplify the main result data if present # Simplify the main result data if present
if "result" in result: if "result" in result:
# Handle array results # Handle array results
@ -218,9 +262,17 @@ def simplify_response(response: dict) -> dict:
simplified_items = [] simplified_items = []
for item in result["result"]: for item in result["result"]:
if isinstance(item, dict): if isinstance(item, dict):
# Remove HATEOAS links from individual items # Store but remove HATEOAS links from individual items
item_copy = item.copy() item_copy = item.copy()
item_copy.pop("_links", None) links = item_copy.pop("_links", None)
# Optionally store direct href links as more accessible properties
# This helps AI agents navigate the API without understanding HATEOAS
if isinstance(links, dict):
for link_name, link_data in links.items():
if isinstance(link_data, dict) and "href" in link_data:
item_copy[f"{link_name}_url"] = link_data["href"]
simplified_items.append(item_copy) simplified_items.append(item_copy)
else: else:
simplified_items.append(item) simplified_items.append(item)
@ -229,8 +281,15 @@ def simplify_response(response: dict) -> dict:
# Handle object results # Handle object results
elif isinstance(result["result"], dict): elif isinstance(result["result"], dict):
result_copy = result["result"].copy() result_copy = result["result"].copy()
# Remove links from result object
result_copy.pop("_links", None) # Store but remove links from result object
links = result_copy.pop("_links", None)
# Add direct href links for easier navigation
if isinstance(links, dict):
for link_name, link_data in links.items():
if isinstance(link_data, dict) and "href" in link_data:
result_copy[f"{link_name}_url"] = link_data["href"]
# Special case for disassembly - convert to text for easier consumption # Special case for disassembly - convert to text for easier consumption
if "instructions" in result_copy and isinstance(result_copy["instructions"], list): if "instructions" in result_copy and isinstance(result_copy["instructions"], list):
@ -256,8 +315,24 @@ def simplify_response(response: dict) -> dict:
result["result"] = result_copy result["result"] = result_copy
# Remove HATEOAS links from the top level # Store but remove HATEOAS links from the top level
result.pop("_links", None) links = result.pop("_links", None)
# Add direct href links in a more accessible format
if isinstance(links, dict):
api_links = {}
for link_name, link_data in links.items():
if isinstance(link_data, dict) and "href" in link_data:
api_links[link_name] = link_data["href"]
# Add simplified links
if api_links:
result["api_links"] = api_links
# Restore API metadata
for key, value in api_metadata.items():
if key not in result:
result[key] = value
return result return result
@ -310,6 +385,17 @@ def register_instance(port: int, url: str = None) -> str:
project_info = {"url": url} project_info = {"url": url}
try: try:
# Check plugin version to ensure compatibility
try:
version_data = response.json()
if "result" in version_data:
result = version_data["result"]
if isinstance(result, dict):
project_info["plugin_version"] = result.get("plugin_version", "")
project_info["api_version"] = result.get("api_version", 0)
except Exception as e:
print(f"Error parsing plugin version: {e}", file=sys.stderr)
# Get program info from HATEOAS API # Get program info from HATEOAS API
info_url = f"{url}/program" info_url = f"{url}/program"
@ -321,12 +407,27 @@ def register_instance(port: int, url: str = None) -> str:
if "result" in info_data: if "result" in info_data:
result = info_data["result"] result = info_data["result"]
if isinstance(result, dict): if isinstance(result, dict):
project_info["project"] = result.get("project", "") # Extract project and file from programId (format: "project:/file")
program_id = result.get("programId", "")
if ":" in program_id:
project_name, file_path = program_id.split(":", 1)
project_info["project"] = project_name
# Remove leading slash from file path if present
if file_path.startswith("/"):
file_path = file_path[1:]
project_info["path"] = file_path
# Get file name directly from the result
project_info["file"] = result.get("name", "") project_info["file"] = result.get("name", "")
project_info["path"] = result.get("path", "")
project_info["language_id"] = result.get("language_id", "") # Get other metadata
project_info["compiler_spec_id"] = result.get("compiler_spec_id", "") project_info["language_id"] = result.get("languageId", "")
project_info["compiler_spec_id"] = result.get("compilerSpecId", "")
project_info["image_base"] = result.get("image_base", "") project_info["image_base"] = result.get("image_base", "")
# Store _links from result for HATEOAS navigation
if "_links" in result:
project_info["_links"] = result.get("_links", {})
except Exception as e: except Exception as e:
print(f"Error parsing info endpoint: {e}", file=sys.stderr) print(f"Error parsing info endpoint: {e}", file=sys.stderr)
except Exception as e: except Exception as e:
@ -386,11 +487,47 @@ def _discover_instances(port_range, host=None, timeout=0.5) -> dict:
try: try:
# Try HATEOAS API via plugin-version endpoint # Try HATEOAS API via plugin-version endpoint
test_url = f"{url}/plugin-version" test_url = f"{url}/plugin-version"
response = requests.get(test_url, timeout=timeout) response = requests.get(test_url,
headers={'Accept': 'application/json',
'X-Request-ID': f"discovery-{int(time.time() * 1000)}"},
timeout=timeout)
if response.ok: if response.ok:
# Further validate it's a GhydraMCP instance by checking response format
try:
json_data = response.json()
if "success" in json_data and json_data["success"] and "result" in json_data:
# Looks like a valid HATEOAS API response
# Instead of relying only on register_instance, which already checks program info,
# extract additional information here for more detailed discovery results
result = register_instance(port, url) result = register_instance(port, url)
found_instances.append(
{"port": port, "url": url, "result": result}) # Initialize report info
instance_info = {
"port": port,
"url": url
}
# Extract version info for reporting
if isinstance(json_data["result"], dict):
instance_info["plugin_version"] = json_data["result"].get("plugin_version", "unknown")
instance_info["api_version"] = json_data["result"].get("api_version", "unknown")
else:
instance_info["plugin_version"] = "unknown"
instance_info["api_version"] = "unknown"
# Include project details from registered instance in the report
if port in active_instances:
instance_info["project"] = active_instances[port].get("project", "")
instance_info["file"] = active_instances[port].get("file", "")
instance_info["result"] = result
found_instances.append(instance_info)
except (ValueError, KeyError):
# Not a valid JSON response or missing expected keys
print(f"Port {port} returned non-HATEOAS response", file=sys.stderr)
continue
except requests.exceptions.RequestException: except requests.exceptions.RequestException:
# Instance not available, just continue # Instance not available, just continue
continue continue
@ -574,12 +711,15 @@ def disassemble_function(port: int = DEFAULT_GHIDRA_PORT,
name: Function name (mutually exclusive with address) name: Function name (mutually exclusive with address)
Returns: Returns:
dict: Contains function information and disassembly text dict: Contains function information and disassembly text, optimized for agent consumption
""" """
if not address and not name: if not address and not name:
return { return {
"success": False, "success": False,
"error": "Either address or name parameter is required", "error": {
"code": "MISSING_PARAMETER",
"message": "Either address or name parameter is required"
},
"timestamp": int(time.time() * 1000) "timestamp": int(time.time() * 1000)
} }
@ -591,11 +731,21 @@ def disassemble_function(port: int = DEFAULT_GHIDRA_PORT,
response = safe_get(port, endpoint) response = safe_get(port, endpoint)
simplified = simplify_response(response) simplified = simplify_response(response)
# For AI consumption, add a plain text version of the disassembly if not already present # For AI consumption, create a simplified response with just the disassembly text
if "result" in simplified and isinstance(simplified["result"], dict): if "result" in simplified and isinstance(simplified["result"], dict):
if "instructions" in simplified["result"] and isinstance(simplified["result"]["instructions"], list): result = simplified["result"]
if "disassembly_text" not in simplified["result"]: function_info = None
instr_list = simplified["result"]["instructions"] disasm_text = None
# Extract function info if available
if "function" in result and isinstance(result["function"], dict):
function_info = result["function"]
# Get the disassembly text, generate it if it doesn't exist
if "disassembly_text" in result:
disasm_text = result["disassembly_text"]
elif "instructions" in result and isinstance(result["instructions"], list):
instr_list = result["instructions"]
disasm_text = "" disasm_text = ""
for instr in instr_list: for instr in instr_list:
if isinstance(instr, dict): if isinstance(instr, dict):
@ -607,10 +757,31 @@ def disassemble_function(port: int = DEFAULT_GHIDRA_PORT,
# Format: address: bytes mnemonic operands # Format: address: bytes mnemonic operands
disasm_text += f"{addr}: {bytes_str.ljust(10)} {mnemonic} {operands}\n" disasm_text += f"{addr}: {bytes_str.ljust(10)} {mnemonic} {operands}\n"
simplified["result"]["disassembly_text"] = disasm_text # Create a simplified result that's easier for agents to consume
# Also make it more directly accessible if disasm_text:
simplified["disassembly_text"] = disasm_text # Create a new response with just the important info
new_response = {
"success": True,
"id": simplified.get("id", ""),
"instance": simplified.get("instance", ""),
"timestamp": simplified.get("timestamp", int(time.time() * 1000)),
"disassembly": disasm_text # Direct access to disassembly text
}
# Add function info if available
if function_info:
new_response["function_name"] = function_info.get("name", "")
new_response["function_address"] = function_info.get("address", "")
if "signature" in function_info:
new_response["function_signature"] = function_info.get("signature", "")
# Preserve API links if available
if "api_links" in simplified:
new_response["api_links"] = simplified["api_links"]
return new_response
# If we couldn't extract disassembly text, return the original response
return simplified return simplified
@ -631,7 +802,10 @@ def get_function_variables(port: int = DEFAULT_GHIDRA_PORT,
if not address and not name: if not address and not name:
return { return {
"success": False, "success": False,
"error": "Either address or name parameter is required", "error": {
"code": "MISSING_PARAMETER",
"message": "Either address or name parameter is required"
},
"timestamp": int(time.time() * 1000) "timestamp": int(time.time() * 1000)
} }
@ -737,6 +911,50 @@ def list_symbols(port: int = DEFAULT_GHIDRA_PORT,
return simplified return simplified
@mcp.tool()
def list_variables(port: int = DEFAULT_GHIDRA_PORT,
                   offset: int = 0,
                   limit: int = 100,
                   search: str = None,
                   global_only: bool = False) -> dict:
    """Fetch one page of the program's variables

    Args:
        port: Ghidra instance port (defaults to DEFAULT_GHIDRA_PORT)
        offset: Pagination offset (default: 0)
        limit: Maximum items to return (default: 100)
        search: Optional search term to filter variables by name
        global_only: If True, only return global variables (default: False)

    Returns:
        dict: The simplified "variables" response. Unless it carries an
        "error" key, it is guaranteed to have "size", "offset" and "limit".
    """
    # Optional filters are only sent when the caller actually supplied them.
    query = {"offset": offset, "limit": limit}
    if search:
        query["search"] = search
    if global_only:
        query["global_only"] = str(global_only).lower()

    page = simplify_response(safe_get(port, "variables", query))

    # Error responses (and anything that is not a dict) pass through untouched.
    if not isinstance(page, dict) or "error" in page:
        return page

    # Backfill pagination metadata only where the server did not provide it.
    items = page.get("result")
    item_count = len(items) if isinstance(items, list) else 0
    for field, fallback in (("size", item_count),
                            ("offset", offset),
                            ("limit", limit)):
        page.setdefault(field, fallback)

    return page
@mcp.tool() @mcp.tool()
def list_data_items(port: int = DEFAULT_GHIDRA_PORT, def list_data_items(port: int = DEFAULT_GHIDRA_PORT,
offset: int = 0, offset: int = 0,
@ -791,7 +1009,7 @@ def list_data_items(port: int = DEFAULT_GHIDRA_PORT,
@mcp.tool() @mcp.tool()
def read_memory(port: int = DEFAULT_GHIDRA_PORT, def read_memory(port: int = DEFAULT_GHIDRA_PORT,
address: str = "", address: str = None,
length: int = 16, length: int = 16,
format: str = "hex") -> dict: format: str = "hex") -> dict:
"""Read bytes from memory """Read bytes from memory
@ -807,35 +1025,50 @@ def read_memory(port: int = DEFAULT_GHIDRA_PORT,
"address": original address, "address": original address,
"length": bytes read, "length": bytes read,
"format": output format, "format": output format,
"bytes": the memory contents as a string in the specified format, "hexBytes": the memory contents as hex string,
"rawBytes": the memory contents as base64 string,
"timestamp": response timestamp "timestamp": response timestamp
} }
""" """
if not address: if not address:
return { return {
"success": False, "success": False,
"error": "Address parameter is required", "error": {
"code": "MISSING_PARAMETER",
"message": "Address parameter is required"
},
"timestamp": int(time.time() * 1000) "timestamp": int(time.time() * 1000)
} }
# Use query parameters instead of path parameters for more reliable handling
params = { params = {
"address": address,
"length": length, "length": length,
"format": format "format": format
} }
response = safe_get(port, f"memory/{address}", params) response = safe_get(port, "memory", params)
simplified = simplify_response(response) simplified = simplify_response(response)
# Ensure the result is simple and directly usable # Ensure the result is simple and directly usable
if "result" in simplified and isinstance(simplified["result"], dict): if "result" in simplified and isinstance(simplified["result"], dict):
bytes_data = simplified["result"].get("bytes", "") result = simplified["result"]
# Pass through all representations of the bytes
memory_info = { memory_info = {
"address": address, "success": True,
"length": length, "address": result.get("address", address),
"length": result.get("bytesRead", length),
"format": format, "format": format,
"bytes": bytes_data,
"timestamp": simplified.get("timestamp", int(time.time() * 1000)) "timestamp": simplified.get("timestamp", int(time.time() * 1000))
} }
# Include all the different byte representations
if "hexBytes" in result:
memory_info["hexBytes"] = result["hexBytes"]
if "rawBytes" in result:
memory_info["rawBytes"] = result["rawBytes"]
return memory_info return memory_info
return simplified return simplified
@ -843,8 +1076,8 @@ def read_memory(port: int = DEFAULT_GHIDRA_PORT,
@mcp.tool() @mcp.tool()
def write_memory(port: int = DEFAULT_GHIDRA_PORT, def write_memory(port: int = DEFAULT_GHIDRA_PORT,
address: str = "", address: str = None,
bytes_data: str = "", bytes_data: str = None,
format: str = "hex") -> dict: format: str = "hex") -> dict:
"""Write bytes to memory (use with caution) """Write bytes to memory (use with caution)
@ -860,10 +1093,23 @@ def write_memory(port: int = DEFAULT_GHIDRA_PORT,
- length: number of bytes written - length: number of bytes written
- bytesWritten: confirmation of bytes written - bytesWritten: confirmation of bytes written
""" """
if not address or not bytes_data: if not address:
return { return {
"success": False, "success": False,
"error": "Address and bytes parameters are required", "error": {
"code": "MISSING_PARAMETER",
"message": "Address parameter is required"
},
"timestamp": int(time.time() * 1000)
}
if not bytes_data:
return {
"success": False,
"error": {
"code": "MISSING_PARAMETER",
"message": "Bytes parameter is required"
},
"timestamp": int(time.time() * 1000) "timestamp": int(time.time() * 1000)
} }
@ -906,7 +1152,10 @@ def list_xrefs(port: int = DEFAULT_GHIDRA_PORT,
if not to_addr and not from_addr: if not to_addr and not from_addr:
return { return {
"success": False, "success": False,
"error": "Either to_addr or from_addr parameter is required", "error": {
"code": "MISSING_PARAMETER",
"message": "Either to_addr or from_addr parameter is required"
},
"timestamp": int(time.time() * 1000) "timestamp": int(time.time() * 1000)
} }
@ -947,11 +1196,21 @@ def list_xrefs(port: int = DEFAULT_GHIDRA_PORT,
flat_ref["from_function"] = ref["from_function"].get("name") flat_ref["from_function"] = ref["from_function"].get("name")
flat_ref["from_function_addr"] = ref["from_function"].get("address") flat_ref["from_function_addr"] = ref["from_function"].get("address")
# Add navigational URLs for HATEOAS
if ref["from_function"].get("address"):
flat_ref["from_function_decompile_url"] = f"functions/{ref['from_function'].get('address')}/decompile"
flat_ref["from_function_disassembly_url"] = f"functions/{ref['from_function'].get('address')}/disassembly"
# Add target function info if available # Add target function info if available
if "to_function" in ref and isinstance(ref["to_function"], dict): if "to_function" in ref and isinstance(ref["to_function"], dict):
flat_ref["to_function"] = ref["to_function"].get("name") flat_ref["to_function"] = ref["to_function"].get("name")
flat_ref["to_function_addr"] = ref["to_function"].get("address") flat_ref["to_function_addr"] = ref["to_function"].get("address")
# Add navigational URLs for HATEOAS
if ref["to_function"].get("address"):
flat_ref["to_function_decompile_url"] = f"functions/{ref['to_function'].get('address')}/decompile"
flat_ref["to_function_disassembly_url"] = f"functions/{ref['to_function'].get('address')}/disassembly"
# Add symbol info if available # Add symbol info if available
if "from_symbol" in ref: if "from_symbol" in ref:
flat_ref["from_symbol"] = ref["from_symbol"] flat_ref["from_symbol"] = ref["from_symbol"]
@ -964,6 +1223,12 @@ def list_xrefs(port: int = DEFAULT_GHIDRA_PORT,
if "to_instruction" in ref: if "to_instruction" in ref:
flat_ref["to_instruction"] = ref["to_instruction"] flat_ref["to_instruction"] = ref["to_instruction"]
# Add other useful HATEOAS links
if flat_ref.get("from_addr"):
flat_ref["from_memory_url"] = f"memory/{flat_ref['from_addr']}"
if flat_ref.get("to_addr"):
flat_ref["to_memory_url"] = f"memory/{flat_ref['to_addr']}"
flat_refs.append(flat_ref) flat_refs.append(flat_ref)
# Add the simplified references # Add the simplified references
@ -980,6 +1245,28 @@ def list_xrefs(port: int = DEFAULT_GHIDRA_PORT,
simplified["xrefs_text"] = "\n".join(text_refs) simplified["xrefs_text"] = "\n".join(text_refs)
# Add navigation links for next/previous pages
if offset > 0:
prev_offset = max(0, offset - limit)
simplified["prev_page_url"] = f"xrefs?offset={prev_offset}&limit={limit}"
if to_addr:
simplified["prev_page_url"] += f"&to_addr={to_addr}"
if from_addr:
simplified["prev_page_url"] += f"&from_addr={from_addr}"
if type:
simplified["prev_page_url"] += f"&type={type}"
total_size = simplified.get("size", 0)
if offset + limit < total_size:
next_offset = offset + limit
simplified["next_page_url"] = f"xrefs?offset={next_offset}&limit={limit}"
if to_addr:
simplified["next_page_url"] += f"&to_addr={to_addr}"
if from_addr:
simplified["next_page_url"] += f"&from_addr={from_addr}"
if type:
simplified["next_page_url"] += f"&type={type}"
return simplified return simplified
@ -1086,7 +1373,10 @@ def rename_function(port: int = DEFAULT_GHIDRA_PORT,
if not (address or name) or not new_name: if not (address or name) or not new_name:
return { return {
"success": False, "success": False,
"error": "Either address or name, and new_name parameters are required", "error": {
"code": "MISSING_PARAMETER",
"message": "Either address or name, and new_name parameters are required"
},
"timestamp": int(time.time() * 1000) "timestamp": int(time.time() * 1000)
} }
@ -1343,7 +1633,7 @@ def get_dataflow(port: int = DEFAULT_GHIDRA_PORT,
@mcp.tool() @mcp.tool()
def set_comment(port: int = DEFAULT_GHIDRA_PORT, def set_comment(port: int = DEFAULT_GHIDRA_PORT,
address: str = "", address: str = None,
comment: str = "", comment: str = "",
comment_type: str = "plate") -> dict: comment_type: str = "plate") -> dict:
"""Set a comment at the specified address """Set a comment at the specified address
@ -1366,7 +1656,10 @@ def set_comment(port: int = DEFAULT_GHIDRA_PORT,
if not address: if not address:
return { return {
"success": False, "success": False,
"error": "Address parameter is required", "error": {
"code": "MISSING_PARAMETER",
"message": "Address parameter is required"
},
"timestamp": int(time.time() * 1000) "timestamp": int(time.time() * 1000)
} }
@ -1378,6 +1671,48 @@ def set_comment(port: int = DEFAULT_GHIDRA_PORT,
return simplify_response(response) return simplify_response(response)
@mcp.tool()
def set_decompiler_comment(port: int = DEFAULT_GHIDRA_PORT,
                           address: str = None,
                           comment: str = "") -> dict:
    """Set a decompiler comment at the specified address

    The function-scoped comments endpoint is tried first. If that request
    raises or does not report success, the comment is written through
    set_comment with the "plate" comment type instead.

    Args:
        port: Ghidra instance port (defaults to DEFAULT_GHIDRA_PORT)
        address: Memory address in hex format
        comment: Comment text

    Returns:
        dict: Operation result
    """
    if not address:
        return {
            "success": False,
            "error": {
                "code": "MISSING_PARAMETER",
                "message": "Address parameter is required"
            },
            "timestamp": int(time.time() * 1000)
        }

    payload = {
        "comment": comment
    }

    # First try to post to the more specific function comments endpoint.
    try:
        response = safe_post(port, f"functions/{address}/comments", payload)
        if response.get("success", False):
            return simplify_response(response)
    except Exception as e:
        # Best-effort attempt: keep the fallback, but leave a trace on stderr
        # (as the rest of this module does) instead of hiding the failure.
        print(f"Function comments endpoint failed for {address}, "
              f"falling back to plate comment: {e}", file=sys.stderr)

    # Fall back to the normal comment mechanism with "plate" type.
    return set_comment(port, address, comment, "plate")
def handle_sigint(signum, frame): def handle_sigint(signum, frame):
os._exit(0) os._exit(0)
@ -1397,6 +1732,41 @@ def periodic_discovery():
response = requests.get(f"{url}/plugin-version", timeout=1) response = requests.get(f"{url}/plugin-version", timeout=1)
if not response.ok: if not response.ok:
ports_to_remove.append(port) ports_to_remove.append(port)
continue
# Update program info if available (especially to get project name)
try:
info_url = f"{url}/program"
info_response = requests.get(info_url, timeout=1)
if info_response.ok:
try:
info_data = info_response.json()
if "result" in info_data:
result = info_data["result"]
if isinstance(result, dict):
# Extract project and file from programId (format: "project:/file")
program_id = result.get("programId", "")
if ":" in program_id:
project_name, file_path = program_id.split(":", 1)
info["project"] = project_name
# Remove leading slash from file path if present
if file_path.startswith("/"):
file_path = file_path[1:]
info["path"] = file_path
# Get file name directly from the result
info["file"] = result.get("name", "")
# Get other metadata
info["language_id"] = result.get("languageId", "")
info["compiler_spec_id"] = result.get("compilerSpecId", "")
info["image_base"] = result.get("image_base", "")
except Exception as e:
print(f"Error parsing info endpoint during discovery: {e}", file=sys.stderr)
except Exception:
# Non-critical, continue even if update fails
pass
except requests.exceptions.RequestException: except requests.exceptions.RequestException:
ports_to_remove.append(port) ports_to_remove.append(port)

View File

@ -33,6 +33,9 @@ public class AnalysisEndpoints extends AbstractEndpoint {
@Override @Override
public void registerEndpoints(HttpServer server) { public void registerEndpoints(HttpServer server) {
server.createContext("/analysis", this::handleAnalysisRequest); server.createContext("/analysis", this::handleAnalysisRequest);
// NOTE: The callgraph endpoint is now registered in ProgramEndpoints
// This comment is to avoid confusion during future maintenance
} }
private void handleAnalysisRequest(HttpExchange exchange) throws IOException { private void handleAnalysisRequest(HttpExchange exchange) throws IOException {

View File

@ -406,10 +406,22 @@ public class FunctionEndpoints extends AbstractEndpoint {
} }
if (signature != null && !signature.isEmpty()) { if (signature != null && !signature.isEmpty()) {
// Update signature - placeholder // Update function signature using our utility method
sendErrorResponse(exchange, 501, "Updating function signature not implemented", "NOT_IMPLEMENTED"); try {
boolean success = TransactionHelper.executeInTransaction(program, "Set Function Signature", () -> {
return GhidraUtil.setFunctionSignature(function, signature);
});
if (!success) {
sendErrorResponse(exchange, 400, "Failed to set function signature: invalid signature format", "SIGNATURE_FAILED");
return; return;
} }
changed = true;
} catch (Exception e) {
sendErrorResponse(exchange, 400, "Failed to set function signature: " + e.getMessage(), "SIGNATURE_FAILED");
return;
}
}
if (comment != null) { if (comment != null) {
// Update comment // Update comment
@ -830,10 +842,22 @@ public class FunctionEndpoints extends AbstractEndpoint {
} }
if (signature != null && !signature.isEmpty()) { if (signature != null && !signature.isEmpty()) {
// Update signature // Update function signature using our utility method
sendErrorResponse(exchange, 501, "Updating function signature not implemented", "NOT_IMPLEMENTED"); try {
boolean success = TransactionHelper.executeInTransaction(program, "Set Function Signature", () -> {
return GhidraUtil.setFunctionSignature(function, signature);
});
if (!success) {
sendErrorResponse(exchange, 400, "Failed to set function signature: invalid signature format", "SIGNATURE_FAILED");
return; return;
} }
changed = true;
} catch (Exception e) {
sendErrorResponse(exchange, 400, "Failed to set function signature: " + e.getMessage(), "SIGNATURE_FAILED");
return;
}
}
if (comment != null) { if (comment != null) {
// Update comment // Update comment

View File

@ -5,11 +5,13 @@ import com.google.gson.JsonObject;
import com.sun.net.httpserver.HttpExchange; import com.sun.net.httpserver.HttpExchange;
import com.sun.net.httpserver.HttpServer; import com.sun.net.httpserver.HttpServer;
import eu.starsong.ghidra.api.ResponseBuilder; import eu.starsong.ghidra.api.ResponseBuilder;
import eu.starsong.ghidra.util.TransactionHelper;
import ghidra.program.model.address.Address; import ghidra.program.model.address.Address;
import ghidra.program.model.address.AddressFactory; import ghidra.program.model.address.AddressFactory;
import ghidra.program.model.mem.Memory; import ghidra.program.model.mem.Memory;
import ghidra.program.model.mem.MemoryAccessException; import ghidra.program.model.mem.MemoryAccessException;
import ghidra.program.model.mem.MemoryBlock; import ghidra.program.model.mem.MemoryBlock;
import ghidra.program.model.listing.CodeUnit;
import ghidra.program.model.listing.Program; import ghidra.program.model.listing.Program;
import ghidra.framework.plugintool.PluginTool; import ghidra.framework.plugintool.PluginTool;
import ghidra.util.Msg; import ghidra.util.Msg;
@ -39,8 +41,25 @@ public class MemoryEndpoints extends AbstractEndpoint {
@Override @Override
public void registerEndpoints(HttpServer server) { public void registerEndpoints(HttpServer server) {
// HttpServer routes each request to the context whose path is the longest
// matching prefix of the request path; the order of createContext() calls
// does not influence routing. One "/memory/" context therefore receives every
// sub-path (e.g. /memory/blocks, /memory/0x1000, /memory/0x1000/comments/plate)
// and dispatches on the path itself.
server.createContext("/memory/", exchange -> {
    String path = exchange.getRequestURI().getPath();
    // Keep serving sub-paths of /memory/blocks from the blocks handler, as a
    // dedicated "/memory/blocks" context would (prefix match).
    if (path.equals("/memory/blocks") || path.startsWith("/memory/blocks/")) {
        handleMemoryBlocksRequest(exchange);
    } else {
        // Plain /memory/{address} and /memory/{address}/comments/{type} are
        // both resolved inside the address handler.
        handleMemoryAddressRequest(exchange);
    }
});
// Bare "/memory" (no trailing slash) is not covered by the "/memory/" context
// and is served by the general memory handler registered below.
server.createContext("/memory", this::handleMemoryRequest); server.createContext("/memory", this::handleMemoryRequest);
server.createContext("/memory/blocks", this::handleMemoryBlocksRequest);
} }
private void handleMemoryRequest(HttpExchange exchange) throws IOException { private void handleMemoryRequest(HttpExchange exchange) throws IOException {
@ -170,6 +189,189 @@ public class MemoryEndpoints extends AbstractEndpoint {
} }
} }
/**
 * Handle requests to /memory/{address}, including child resources such as
 * /memory/{address}/comments/{comment_type}.
 *
 * Requests for "/memory" or "/memory/" are forwarded unchanged to the general
 * memory handler. Comment requests are forwarded to the comment handler; when
 * the comment type segment is empty ("/memory/{address}/comments/") the plate
 * comment is used. Any other path is treated as an address and delegated to
 * the general memory handler.
 *
 * @param exchange the HTTP exchange being served
 * @throws IOException propagated from the fallback error response in the catch block
 */
private void handleMemoryAddressRequest(HttpExchange exchange) throws IOException {
    try {
        String path = exchange.getRequestURI().getPath();
        if (path.equals("/memory/") || path.equals("/memory")) {
            handleMemoryRequest(exchange);
            return;
        }

        // Everything after the "/memory/" prefix: "{address}" or "{address}/comments/{type}"
        String remainingPath = path.substring("/memory/".length());

        if (remainingPath.contains("/comments/")) {
            // split(..., 2) always yields two parts here, so an absent type shows
            // up as an empty second part rather than as a missing element.
            String[] parts = remainingPath.split("/comments/", 2);
            String addressStr = parts[0];
            String commentType = (parts.length > 1 && !parts[1].isEmpty()) ? parts[1] : "plate";

            handleMemoryComments(exchange, addressStr, commentType);
            return;
        }

        // Expose the path address to the general handler through an exchange attribute.
        // NOTE(review): confirm handleMemoryRequest reads this attribute; query
        // parameters parsed there will not contain the path address.
        exchange.setAttribute("address", remainingPath);

        handleMemoryRequest(exchange);
    } catch (Exception e) {
        Msg.error(this, "Error handling memory address endpoint", e);
        sendErrorResponse(exchange, 500, "Internal server error: " + e.getMessage(), "INTERNAL_ERROR");
    }
}
/**
 * Handle requests to get (GET) or set (POST) a comment at a specific memory address.
 *
 * Responds with 400 when no program is loaded, the address cannot be parsed,
 * the comment type is not recognised, or a POST body lacks "comment"; with 405
 * for any other HTTP method; and with 500 when storing the comment fails or an
 * unexpected exception occurs.
 *
 * @param exchange    the HTTP exchange being served
 * @param addressStr  address text taken from the request path
 * @param commentType comment kind taken from the request path
 * @throws IOException propagated from the fallback error response in the catch block
 */
private void handleMemoryComments(HttpExchange exchange, String addressStr, String commentType) throws IOException {
    try {
        String method = exchange.getRequestMethod();
        Program program = getCurrentProgram();
        if (program == null) {
            sendErrorResponse(exchange, 400, "No program loaded", "NO_PROGRAM_LOADED");
            return;
        }

        // AddressFactory.getAddress may signal an unparsable string by returning
        // null rather than throwing, so both outcomes are mapped to a 400.
        Address address;
        try {
            address = program.getAddressFactory().getAddress(addressStr);
        } catch (Exception e) {
            address = null;
        }
        if (address == null) {
            sendErrorResponse(exchange, 400, "Invalid address format: " + addressStr, "INVALID_ADDRESS");
            return;
        }

        if (!isValidCommentType(commentType)) {
            sendErrorResponse(exchange, 400, "Invalid comment type: " + commentType, "INVALID_COMMENT_TYPE");
            return;
        }

        if ("GET".equals(method)) {
            // A missing comment is reported as an empty string, not null
            String comment = getCommentByType(program, address, commentType);
            sendCommentResult(exchange, addressStr, commentType, comment != null ? comment : "");
        } else if ("POST".equals(method)) {
            Map<String, String> params = parseJsonPostParams(exchange);
            String comment = params.get("comment");
            if (comment == null) {
                sendErrorResponse(exchange, 400, "Comment parameter is required", "MISSING_PARAMETER");
                return;
            }

            if (setCommentByType(program, address, commentType, comment)) {
                sendCommentResult(exchange, addressStr, commentType, comment);
            } else {
                sendErrorResponse(exchange, 500, "Failed to set comment", "COMMENT_SET_FAILED");
            }
        } else {
            sendErrorResponse(exchange, 405, "Method Not Allowed", "METHOD_NOT_ALLOWED");
        }
    } catch (Exception e) {
        Msg.error(this, "Error handling memory comments", e);
        sendErrorResponse(exchange, 500, "Internal server error: " + e.getMessage(), "INTERNAL_ERROR");
    }
}

/**
 * Send the 200 response shared by the GET and POST comment paths: the address,
 * comment type and comment text plus a "self" link to the comment resource.
 */
private void sendCommentResult(HttpExchange exchange, String addressStr, String commentType, String comment) throws IOException {
    Map<String, Object> result = new HashMap<>();
    result.put("address", addressStr);
    result.put("comment_type", commentType);
    result.put("comment", comment);

    ResponseBuilder builder = new ResponseBuilder(exchange, port)
        .success(true)
        .result(result)
        .addLink("self", "/memory/" + addressStr + "/comments/" + commentType);

    sendJsonResponse(exchange, builder.build(), 200);
}
/**
 * Check whether the given string names a supported comment type.
 *
 * Matching ignores case, in line with getCommentTypeInt, which lower-cases
 * its input before mapping it. A null argument is reported as invalid.
 *
 * @param commentType candidate comment type, may be null
 * @return true for plate, pre, post, eol or repeatable (any letter case)
 */
private boolean isValidCommentType(String commentType) {
    if (commentType == null) {
        return false;
    }
    // Locale.ROOT keeps the comparison independent of the JVM default locale
    switch (commentType.toLowerCase(java.util.Locale.ROOT)) {
        case "plate":
        case "pre":
        case "post":
        case "eol":
        case "repeatable":
            return true;
        default:
            return false;
    }
}
/**
 * Look up the comment of the requested kind stored at an address.
 *
 * @param program     program to query; null yields null
 * @param address     address whose comment is requested
 * @param commentType textual comment kind, mapped via getCommentTypeInt
 * @return whatever the program listing returns for that kind and address,
 *         or null when no program was supplied
 */
private String getCommentByType(Program program, Address address, String commentType) {
    if (program != null) {
        // Translate the textual kind first, then ask the listing
        final int ghidraType = getCommentTypeInt(commentType);
        return program.getListing().getComment(ghidraType, address);
    }
    return null;
}
/**
 * Store a comment of the requested kind at an address.
 *
 * The write is performed through TransactionHelper under the transaction
 * name "Set Comment". Any exception is logged and reported as failure.
 *
 * @param program     program to modify; null yields false
 * @param address     address to attach the comment to
 * @param commentType textual comment kind, mapped via getCommentTypeInt
 * @param comment     comment text to store
 * @return the transaction helper's result, or false on a null program or an exception
 */
private boolean setCommentByType(Program program, Address address, String commentType, String comment) {
    if (program == null) {
        return false;
    }

    final int ghidraType = getCommentTypeInt(commentType);
    boolean stored;
    try {
        stored = TransactionHelper.executeInTransaction(program, "Set Comment", () -> {
            program.getListing().setComment(address, ghidraType, comment);
            return true;
        });
    } catch (Exception e) {
        Msg.error(this, "Error setting comment", e);
        stored = false;
    }
    return stored;
}
/**
 * Map a textual comment kind onto the matching CodeUnit comment constant.
 *
 * The input is lower-cased before comparison; anything that is not pre,
 * post, eol or repeatable (including "plate" itself) maps to the plate
 * comment constant.
 *
 * @param commentType textual comment kind; must not be null
 * @return the CodeUnit comment constant for that kind
 */
private int getCommentTypeInt(String commentType) {
    final String key = commentType.toLowerCase();
    if ("pre".equals(key)) {
        return CodeUnit.PRE_COMMENT;
    }
    if ("post".equals(key)) {
        return CodeUnit.POST_COMMENT;
    }
    if ("eol".equals(key)) {
        return CodeUnit.EOL_COMMENT;
    }
    if ("repeatable".equals(key)) {
        return CodeUnit.REPEATABLE_COMMENT;
    }
    // "plate" and every unrecognised value share the same result
    return CodeUnit.PLATE_COMMENT;
}
private void handleMemoryBlocksRequest(HttpExchange exchange) throws IOException { private void handleMemoryBlocksRequest(HttpExchange exchange) throws IOException {
try { try {
if ("GET".equals(exchange.getRequestMethod())) { if ("GET".equals(exchange.getRequestMethod())) {

View File

@ -50,6 +50,8 @@ public class ProgramEndpoints extends AbstractEndpoint {
server.createContext("/address", this::handleCurrentAddress); server.createContext("/address", this::handleCurrentAddress);
server.createContext("/function", this::handleCurrentFunction); server.createContext("/function", this::handleCurrentFunction);
// Register direct analysis endpoints according to HATEOAS API
server.createContext("/analysis/callgraph", this::handleCallGraph);
} }
@Override @Override
@ -1283,8 +1285,8 @@ public class ProgramEndpoints extends AbstractEndpoint {
sendErrorResponse(exchange, 405, "Method Not Allowed", "METHOD_NOT_ALLOWED"); sendErrorResponse(exchange, 405, "Method Not Allowed", "METHOD_NOT_ALLOWED");
} }
} else if (path.equals("/callgraph") || path.startsWith("/callgraph/")) { } else if (path.equals("/callgraph") || path.startsWith("/callgraph/")) {
// Handle call graph generation // Handle call graph generation - for backward compatibility
handleCallGraph(exchange, program, path); handleCallGraph(exchange);
} else if (path.equals("/dataflow") || path.startsWith("/dataflow/")) { } else if (path.equals("/dataflow") || path.startsWith("/dataflow/")) {
// Handle data flow analysis // Handle data flow analysis
handleDataFlow(exchange, program, path); handleDataFlow(exchange, program, path);
@ -1460,7 +1462,14 @@ public class ProgramEndpoints extends AbstractEndpoint {
/** /**
* Handle call graph generation * Handle call graph generation
*/ */
private void handleCallGraph(HttpExchange exchange, Program program, String path) throws IOException { private void handleCallGraph(HttpExchange exchange) throws IOException {
Program program = getCurrentProgram();
if (program == null) {
sendErrorResponse(exchange, 404, "No program is currently open", "NO_PROGRAM_OPEN");
return;
}
String path = "";
if (!"GET".equals(exchange.getRequestMethod())) { if (!"GET".equals(exchange.getRequestMethod())) {
sendErrorResponse(exchange, 405, "Method Not Allowed", "METHOD_NOT_ALLOWED"); sendErrorResponse(exchange, 405, "Method Not Allowed", "METHOD_NOT_ALLOWED");
return; return;

View File

@ -13,8 +13,10 @@ import ghidra.program.model.data.DataTypeManager;
import ghidra.program.model.listing.Function; import ghidra.program.model.listing.Function;
import ghidra.program.model.listing.FunctionManager; import ghidra.program.model.listing.FunctionManager;
import ghidra.program.model.listing.Parameter; import ghidra.program.model.listing.Parameter;
import ghidra.program.model.listing.ParameterImpl;
import ghidra.program.model.listing.Program; import ghidra.program.model.listing.Program;
import ghidra.program.model.listing.Variable; import ghidra.program.model.listing.Variable;
import ghidra.program.model.symbol.SourceType;
import ghidra.program.model.pcode.HighFunction; import ghidra.program.model.pcode.HighFunction;
import ghidra.program.model.pcode.HighVariable; import ghidra.program.model.pcode.HighVariable;
import ghidra.program.model.pcode.PcodeOp; import ghidra.program.model.pcode.PcodeOp;
@ -397,4 +399,87 @@ public class GhidraUtil {
return variables; return variables;
} }
/**
 * Applies a function signature to an existing function.
 *
 * The signature string is parsed against the function's current signature
 * with Ghidra's FunctionSignatureParser. The return type, the calling
 * convention (when the parsed definition carries one), the parameters and
 * the varargs flag are then written to the function with USER_DEFINED source.
 * All replacement parameters are constructed before the function is modified,
 * so an invalid parameter is rejected without leaving the function half-updated.
 *
 * The caller is expected to run this inside a program transaction.
 *
 * @param function The function to update
 * @param signatureStr The C-style function signature string
 * @return true if successful, false otherwise (null/empty input, no program,
 *         unparsable signature, or any error while applying it; errors are logged)
 */
public static boolean setFunctionSignature(Function function, String signatureStr) {
    if (function == null || signatureStr == null || signatureStr.isEmpty()) {
        return false;
    }

    Program program = function.getProgram();
    if (program == null) {
        return false;
    }

    try {
        ghidra.app.util.parser.FunctionSignatureParser parser =
            new ghidra.app.util.parser.FunctionSignatureParser(
                program.getDataTypeManager(), null);

        ghidra.program.model.data.FunctionDefinitionDataType functionDef =
            parser.parse(function.getSignature(), signatureStr);
        if (functionDef == null) {
            return false;
        }

        // Build every replacement parameter up front: a bad name or data type
        // fails here, before anything on the function has been changed.
        java.util.List<Parameter> newParams = new java.util.ArrayList<>();
        ghidra.program.model.data.ParameterDefinition[] paramDefs = functionDef.getArguments();
        if (paramDefs != null) {
            for (ghidra.program.model.data.ParameterDefinition paramDef : paramDefs) {
                newParams.add(new ParameterImpl(paramDef.getName(), paramDef.getDataType(), program));
            }
        }

        function.setReturnType(functionDef.getReturnType(), SourceType.USER_DEFINED);

        // Only override the calling convention when the signature specifies one
        if (functionDef.getCallingConvention() != null) {
            function.setCallingConvention(functionDef.getCallingConvention().getName());
        }

        // Replace the existing parameter list with the parsed one
        while (function.getParameterCount() > 0) {
            function.removeParameter(0);
        }
        for (Parameter param : newParams) {
            function.addParameter(param, SourceType.USER_DEFINED);
        }

        // Carry over a trailing "..." from the parsed signature
        function.setVarArgs(functionDef.hasVarArgs());

        return true;
    } catch (ghidra.util.exception.InvalidInputException e) {
        ghidra.util.Msg.error(GhidraUtil.class,
            "Error setting function parameters: " + e.getMessage(), e);
        return false;
    } catch (Exception e) {
        ghidra.util.Msg.error(GhidraUtil.class,
            "Error setting function signature: " + e.getMessage(), e);
        return false;
    }
}
} }