From 9b19011b7d2a8e2d974829d142cba5524a6027e5 Mon Sep 17 00:00:00 2001 From: Teal Bauer Date: Mon, 14 Apr 2025 11:25:22 +0200 Subject: [PATCH] finalize HATEOAS updates --- bridge_mcp_hydra.py | 484 +++++++++++++++--- .../ghidra/endpoints/AnalysisEndpoints.java | 3 + .../ghidra/endpoints/FunctionEndpoints.java | 36 +- .../ghidra/endpoints/MemoryEndpoints.java | 206 +++++++- .../ghidra/endpoints/ProgramEndpoints.java | 15 +- .../eu/starsong/ghidra/util/GhidraUtil.java | 85 +++ 6 files changed, 761 insertions(+), 68 deletions(-) diff --git a/bridge_mcp_hydra.py b/bridge_mcp_hydra.py index 9aa092f..f809fad 100644 --- a/bridge_mcp_hydra.py +++ b/bridge_mcp_hydra.py @@ -82,7 +82,13 @@ def _make_request(method: str, port: int, endpoint: str, params: dict = None, headers: dict = None) -> dict: """Internal helper to make HTTP requests and handle common errors.""" url = f"{get_instance_url(port)}/{endpoint}" - request_headers = {'Accept': 'application/json'} + + # Set up headers according to HATEOAS API expected format + request_headers = { + 'Accept': 'application/json', + 'X-Request-ID': f"mcp-bridge-{int(time.time() * 1000)}" + } + if headers: request_headers.update(headers) @@ -93,7 +99,10 @@ def _make_request(method: str, port: int, endpoint: str, params: dict = None, if not validate_origin(check_headers): return { "success": False, - "error": "Origin not allowed", + "error": { + "code": "ORIGIN_NOT_ALLOWED", + "message": "Origin not allowed for state-changing request" + }, "status_code": 403, "timestamp": int(time.time() * 1000) } @@ -115,15 +124,32 @@ def _make_request(method: str, port: int, endpoint: str, params: dict = None, try: parsed_json = response.json() + # Add timestamp if not present if isinstance(parsed_json, dict) and "timestamp" not in parsed_json: parsed_json["timestamp"] = int(time.time() * 1000) + + # Check for HATEOAS compliant error response format and reformat if needed + if not response.ok and isinstance(parsed_json, dict) and "success" in parsed_json and not parsed_json["success"]: + # Check if error is in the expected HATEOAS format + if "error" in parsed_json and not isinstance(parsed_json["error"], dict): + # Convert string error to the proper format + error_message = parsed_json["error"] + parsed_json["error"] = { + "code": f"HTTP_{response.status_code}", + "message": error_message + } + return parsed_json + except ValueError: if response.ok: return { "success": False, - "error": "Received non-JSON success response from Ghidra plugin", + "error": { + "code": "NON_JSON_RESPONSE", + "message": "Received non-JSON success response from Ghidra plugin" + }, "status_code": response.status_code, "response_text": response.text[:500], "timestamp": int(time.time() * 1000) @@ -131,7 +157,10 @@ def _make_request(method: str, port: int, endpoint: str, params: dict = None, else: return { "success": False, - "error": f"HTTP {response.status_code} - Non-JSON error response", + "error": { + "code": f"HTTP_{response.status_code}", + "message": f"Non-JSON error response: {response.text[:100]}..." + }, "status_code": response.status_code, "response_text": response.text[:500], "timestamp": int(time.time() * 1000) @@ -140,21 +169,30 @@ def _make_request(method: str, port: int, endpoint: str, params: dict = None, except requests.exceptions.Timeout: return { "success": False, - "error": "Request timed out", + "error": { + "code": "REQUEST_TIMEOUT", + "message": "Request timed out" + }, "status_code": 408, "timestamp": int(time.time() * 1000) } except requests.exceptions.ConnectionError: return { "success": False, - "error": f"Failed to connect to Ghidra instance at {url}", + "error": { + "code": "CONNECTION_ERROR", + "message": f"Failed to connect to Ghidra instance at {url}" + }, "status_code": 503, "timestamp": int(time.time() * 1000) } except Exception as e: return { "success": False, - "error": f"An unexpected error occurred: {str(e)}", + "error": { + "code": "UNEXPECTED_ERROR", + "message": f"An unexpected error occurred: {str(e)}" + }, "exception": e.__class__.__name__, "timestamp": int(time.time() * 1000) } @@ -211,6 +249,12 @@ def simplify_response(response: dict) -> dict: # Make a copy to avoid modifying the original result = response.copy() + # Store API response metadata + api_metadata = {} + for key in ["id", "instance", "timestamp", "size", "offset", "limit"]: + if key in result: + api_metadata[key] = result.get(key) + # Simplify the main result data if present if "result" in result: # Handle array results @@ -218,9 +262,17 @@ def simplify_response(response: dict) -> dict: simplified_items = [] for item in result["result"]: if isinstance(item, dict): - # Remove HATEOAS links from individual items + # Store but remove HATEOAS links from individual items item_copy = item.copy() - item_copy.pop("_links", None) + links = item_copy.pop("_links", None) + + # Optionally store direct href links as more accessible properties + # This helps AI agents navigate the API without understanding HATEOAS + if isinstance(links, dict): + for link_name, link_data in links.items(): + if isinstance(link_data, dict) and "href" in link_data: + item_copy[f"{link_name}_url"] = link_data["href"] + simplified_items.append(item_copy) else: simplified_items.append(item) @@ -229,8 +281,15 @@ def simplify_response(response: dict) -> dict: # Handle object results elif isinstance(result["result"], dict): result_copy = result["result"].copy() - # Remove links from result object - result_copy.pop("_links", None) + + # Store but remove links from result object + links = result_copy.pop("_links", None) + + # Add direct href links for easier navigation + if isinstance(links, dict): + for link_name, link_data in links.items(): + if isinstance(link_data, dict) and "href" in link_data: + result_copy[f"{link_name}_url"] = link_data["href"] # Special case for disassembly - convert to text for easier consumption if "instructions" in result_copy and isinstance(result_copy["instructions"], list): @@ -256,8 +315,24 @@ def simplify_response(response: dict) -> dict: result["result"] = result_copy - # Remove HATEOAS links from the top level - result.pop("_links", None) + # Store but remove HATEOAS links from the top level + links = result.pop("_links", None) + + # Add direct href links in a more accessible format + if isinstance(links, dict): + api_links = {} + for link_name, link_data in links.items(): + if isinstance(link_data, dict) and "href" in link_data: + api_links[link_name] = link_data["href"] + + # Add simplified links + if api_links: + result["api_links"] = api_links + + # Restore API metadata + for key, value in api_metadata.items(): + if key not in result: + result[key] = value return result @@ -310,6 +385,17 @@ def register_instance(port: int, url: str = None) -> str: project_info = {"url": url} try: + # Check plugin version to ensure compatibility + try: + version_data = response.json() + if "result" in version_data: + result = version_data["result"] + if isinstance(result, dict): + project_info["plugin_version"] = result.get("plugin_version", "") + project_info["api_version"] = result.get("api_version", 0) + except Exception as e: + print(f"Error parsing plugin version: {e}", file=sys.stderr) + # Get program info from HATEOAS API info_url = f"{url}/program" @@ -321,12 +407,27 @@ def register_instance(port: int, url: str = None) -> str: if "result" in info_data: result = info_data["result"] if isinstance(result, dict): - project_info["project"] = result.get("project", "") + # Extract project and file from programId (format: "project:/file") + program_id = result.get("programId", "") + if ":" in program_id: + project_name, file_path = program_id.split(":", 1) + project_info["project"] = project_name + # Remove leading slash from file path if present + if file_path.startswith("/"): + file_path = file_path[1:] + project_info["path"] = file_path + + # Get file name directly from the result project_info["file"] = result.get("name", "") - project_info["path"] = result.get("path", "") - project_info["language_id"] = result.get("language_id", "") - project_info["compiler_spec_id"] = result.get("compiler_spec_id", "") + + # Get other metadata + project_info["language_id"] = result.get("languageId", "") + project_info["compiler_spec_id"] = result.get("compilerSpecId", "") project_info["image_base"] = result.get("image_base", "") + + # Store _links from result for HATEOAS navigation + if "_links" in result: + project_info["_links"] = result.get("_links", {}) except Exception as e: print(f"Error parsing info endpoint: {e}", file=sys.stderr) except Exception as e: @@ -386,11 +487,47 @@ def _discover_instances(port_range, host=None, timeout=0.5) -> dict: try: # Try HATEOAS API via plugin-version endpoint test_url = f"{url}/plugin-version" - response = requests.get(test_url, timeout=timeout) + response = requests.get(test_url, + headers={'Accept': 'application/json', + 'X-Request-ID': f"discovery-{int(time.time() * 1000)}"}, + timeout=timeout) + if response.ok: - result = register_instance(port, url) - found_instances.append( - {"port": port, "url": url, "result": result}) + # Further validate it's a GhydraMCP instance by checking response format + try: + json_data = response.json() + if "success" in json_data and json_data["success"] and "result" in json_data: + # Looks like a valid HATEOAS API response + # Instead of relying only on register_instance, which already checks program info, + # extract additional information here for more detailed discovery results + result = register_instance(port, url) + + # Initialize report info + instance_info = { + "port": port, + "url": url + } + + # Extract version info for reporting + if isinstance(json_data["result"], dict): + instance_info["plugin_version"] = json_data["result"].get("plugin_version", "unknown") + instance_info["api_version"] = json_data["result"].get("api_version", "unknown") + else: + instance_info["plugin_version"] = "unknown" + instance_info["api_version"] = "unknown" + + # Include project details from registered instance in the report + if port in active_instances: + instance_info["project"] = active_instances[port].get("project", "") + instance_info["file"] = active_instances[port].get("file", "") + + instance_info["result"] = result + found_instances.append(instance_info) + except (ValueError, KeyError): + # Not a valid JSON response or missing expected keys + print(f"Port {port} returned non-HATEOAS response", file=sys.stderr) + continue + except requests.exceptions.RequestException: # Instance not available, just continue continue @@ -574,12 +711,15 @@ def disassemble_function(port: int = DEFAULT_GHIDRA_PORT, name: Function name (mutually exclusive with address) Returns: - dict: Contains function information and disassembly text + dict: Contains function information and disassembly text, optimized for agent consumption """ if not address and not name: return { "success": False, - "error": "Either address or name parameter is required", + "error": { + "code": "MISSING_PARAMETER", + "message": "Either address or name parameter is required" + }, "timestamp": int(time.time() * 1000) } @@ -591,26 +731,57 @@ def disassemble_function(port: int = DEFAULT_GHIDRA_PORT, response = safe_get(port, endpoint) simplified = simplify_response(response) - # For AI consumption, add a plain text version of the disassembly if not already present + # For AI consumption, create a simplified response with just the disassembly text if "result" in simplified and isinstance(simplified["result"], dict): - if "instructions" in simplified["result"] and isinstance(simplified["result"]["instructions"], list): - if "disassembly_text" not in simplified["result"]: - instr_list = simplified["result"]["instructions"] - disasm_text = "" - for instr in instr_list: - if isinstance(instr, dict): - addr = instr.get("address", "") - mnemonic = instr.get("mnemonic", "") - operands = instr.get("operands", "") - bytes_str = instr.get("bytes", "") - - # Format: address: bytes mnemonic operands - disasm_text += f"{addr}: {bytes_str.ljust(10)} {mnemonic} {operands}\n" + result = simplified["result"] + function_info = None + disasm_text = None + + # Extract function info if available + if "function" in result and isinstance(result["function"], dict): + function_info = result["function"] + + # Get the disassembly text, generate it if it doesn't exist + if "disassembly_text" in result: + disasm_text = result["disassembly_text"] + elif "instructions" in result and isinstance(result["instructions"], list): + instr_list = result["instructions"] + disasm_text = "" + for instr in instr_list: + if isinstance(instr, dict): + addr = instr.get("address", "") + mnemonic = instr.get("mnemonic", "") + operands = instr.get("operands", "") + bytes_str = instr.get("bytes", "") + + # Format: address: bytes mnemonic operands + disasm_text += f"{addr}: {bytes_str.ljust(10)} {mnemonic} {operands}\n" + + # Create a simplified result that's easier for agents to consume + if disasm_text: + # Create a new response with just the important info + new_response = { + "success": True, + "id": simplified.get("id", ""), + "instance": simplified.get("instance", ""), + "timestamp": simplified.get("timestamp", int(time.time() * 1000)), + "disassembly": disasm_text # Direct access to disassembly text + } + + # Add function info if available + if function_info: + new_response["function_name"] = function_info.get("name", "") + new_response["function_address"] = function_info.get("address", "") + if "signature" in function_info: + new_response["function_signature"] = function_info.get("signature", "") - simplified["result"]["disassembly_text"] = disasm_text - # Also make it more directly accessible - simplified["disassembly_text"] = disasm_text + # Preserve API links if available + if "api_links" in simplified: + new_response["api_links"] = simplified["api_links"] + + return new_response + # If we couldn't extract disassembly text, return the original response return simplified @@ -631,7 +802,10 @@ def get_function_variables(port: int = DEFAULT_GHIDRA_PORT, if not address and not name: return { "success": False, - "error": "Either address or name parameter is required", + "error": { + "code": "MISSING_PARAMETER", + "message": "Either address or name parameter is required" + }, "timestamp": int(time.time() * 1000) } @@ -737,6 +911,50 @@ def list_symbols(port: int = DEFAULT_GHIDRA_PORT, return simplified +@mcp.tool() +def list_variables(port: int = DEFAULT_GHIDRA_PORT, + offset: int = 0, + limit: int = 100, + search: str = None, + global_only: bool = False) -> dict: + """List all variables in the program with pagination + + Args: + port: Ghidra instance port (default: 8192) + offset: Pagination offset (default: 0) + limit: Maximum items to return (default: 100) + search: Optional search term to filter variables by name + global_only: If True, only return global variables (default: False) + + Returns: + dict: Contains list of variables with metadata and pagination info + """ + params = { + "offset": offset, + "limit": limit + } + + if search: + params["search"] = search + + if global_only: + params["global_only"] = str(global_only).lower() + + response = safe_get(port, "variables", params) + simplified = simplify_response(response) + + # Ensure we maintain pagination metadata + if isinstance(simplified, dict) and "error" not in simplified: + result_size = 0 + if "result" in simplified and isinstance(simplified["result"], list): + result_size = len(simplified["result"]) + simplified.setdefault("size", result_size) + simplified.setdefault("offset", offset) + simplified.setdefault("limit", limit) + + return simplified + + @mcp.tool() def list_data_items(port: int = DEFAULT_GHIDRA_PORT, offset: int = 0, @@ -791,7 +1009,7 @@ def list_data_items(port: int = DEFAULT_GHIDRA_PORT, @mcp.tool() def read_memory(port: int = DEFAULT_GHIDRA_PORT, - address: str = "", + address: str = None, length: int = 16, format: str = "hex") -> dict: """Read bytes from memory @@ -807,35 +1025,50 @@ def read_memory(port: int = DEFAULT_GHIDRA_PORT, "address": original address, "length": bytes read, "format": output format, - "bytes": the memory contents as a string in the specified format, + "hexBytes": the memory contents as hex string, + "rawBytes": the memory contents as base64 string, "timestamp": response timestamp } """ if not address: return { "success": False, - "error": "Address parameter is required", + "error": { + "code": "MISSING_PARAMETER", + "message": "Address parameter is required" + }, "timestamp": int(time.time() * 1000) } + # Use query parameters instead of path parameters for more reliable handling params = { + "address": address, "length": length, "format": format } - response = safe_get(port, f"memory/{address}", params) + response = safe_get(port, "memory", params) simplified = simplify_response(response) # Ensure the result is simple and directly usable if "result" in simplified and isinstance(simplified["result"], dict): - bytes_data = simplified["result"].get("bytes", "") + result = simplified["result"] + + # Pass through all representations of the bytes memory_info = { - "address": address, - "length": length, + "success": True, + "address": result.get("address", address), + "length": result.get("bytesRead", length), "format": format, - "bytes": bytes_data, "timestamp": simplified.get("timestamp", int(time.time() * 1000)) } + + # Include all the different byte representations + if "hexBytes" in result: + memory_info["hexBytes"] = result["hexBytes"] + if "rawBytes" in result: + memory_info["rawBytes"] = result["rawBytes"] + return memory_info return simplified @@ -843,8 +1076,8 @@ def read_memory(port: int = DEFAULT_GHIDRA_PORT, @mcp.tool() def write_memory(port: int = DEFAULT_GHIDRA_PORT, - address: str = "", - bytes_data: str = "", + address: str = None, + bytes_data: str = None, format: str = "hex") -> dict: """Write bytes to memory (use with caution) @@ -860,10 +1093,23 @@ def write_memory(port: int = DEFAULT_GHIDRA_PORT, - length: number of bytes written - bytesWritten: confirmation of bytes written """ - if not address or not bytes_data: + if not address: return { "success": False, - "error": "Address and bytes parameters are required", + "error": { + "code": "MISSING_PARAMETER", + "message": "Address parameter is required" + }, + "timestamp": int(time.time() * 1000) + } + + if not bytes_data: + return { + "success": False, + "error": { + "code": "MISSING_PARAMETER", + "message": "Bytes parameter is required" + }, "timestamp": int(time.time() * 1000) } @@ -906,7 +1152,10 @@ def list_xrefs(port: int = DEFAULT_GHIDRA_PORT, if not to_addr and not from_addr: return { "success": False, - "error": "Either to_addr or from_addr parameter is required", + "error": { + "code": "MISSING_PARAMETER", + "message": "Either to_addr or from_addr parameter is required" + }, "timestamp": int(time.time() * 1000) } @@ -946,11 +1195,21 @@ def list_xrefs(port: int = DEFAULT_GHIDRA_PORT, if "from_function" in ref and isinstance(ref["from_function"], dict): flat_ref["from_function"] = ref["from_function"].get("name") flat_ref["from_function_addr"] = ref["from_function"].get("address") + + # Add navigational URLs for HATEOAS + if ref["from_function"].get("address"): + flat_ref["from_function_decompile_url"] = f"functions/{ref['from_function'].get('address')}/decompile" + flat_ref["from_function_disassembly_url"] = f"functions/{ref['from_function'].get('address')}/disassembly" # Add target function info if available if "to_function" in ref and isinstance(ref["to_function"], dict): flat_ref["to_function"] = ref["to_function"].get("name") flat_ref["to_function_addr"] = ref["to_function"].get("address") + + # Add navigational URLs for HATEOAS + if ref["to_function"].get("address"): + flat_ref["to_function_decompile_url"] = f"functions/{ref['to_function'].get('address')}/decompile" + flat_ref["to_function_disassembly_url"] = f"functions/{ref['to_function'].get('address')}/disassembly" # Add symbol info if available if "from_symbol" in ref: @@ -963,6 +1222,12 @@ def list_xrefs(port: int = DEFAULT_GHIDRA_PORT, flat_ref["from_instruction"] = ref["from_instruction"] if "to_instruction" in ref: flat_ref["to_instruction"] = ref["to_instruction"] + + # Add other useful HATEOAS links + if flat_ref.get("from_addr"): + flat_ref["from_memory_url"] = f"memory/{flat_ref['from_addr']}" + if flat_ref.get("to_addr"): + flat_ref["to_memory_url"] = f"memory/{flat_ref['to_addr']}" flat_refs.append(flat_ref) @@ -979,6 +1244,28 @@ def list_xrefs(port: int = DEFAULT_GHIDRA_PORT, text_refs.append(line) simplified["xrefs_text"] = "\n".join(text_refs) + + # Add navigation links for next/previous pages + if offset > 0: + prev_offset = max(0, offset - limit) + simplified["prev_page_url"] = f"xrefs?offset={prev_offset}&limit={limit}" + if to_addr: + simplified["prev_page_url"] += f"&to_addr={to_addr}" + if from_addr: + simplified["prev_page_url"] += f"&from_addr={from_addr}" + if type: + simplified["prev_page_url"] += f"&type={type}" + + total_size = simplified.get("size", 0) + if offset + limit < total_size: + next_offset = offset + limit + simplified["next_page_url"] = f"xrefs?offset={next_offset}&limit={limit}" + if to_addr: + simplified["next_page_url"] += f"&to_addr={to_addr}" + if from_addr: + simplified["next_page_url"] += f"&from_addr={from_addr}" + if type: + simplified["next_page_url"] += f"&type={type}" return simplified @@ -1086,7 +1373,10 @@ def rename_function(port: int = DEFAULT_GHIDRA_PORT, if not (address or name) or not new_name: return { "success": False, - "error": "Either address or name, and new_name parameters are required", + "error": { + "code": "MISSING_PARAMETER", + "message": "Either address or name, and new_name parameters are required" + }, "timestamp": int(time.time() * 1000) } @@ -1343,7 +1633,7 @@ def get_dataflow(port: int = DEFAULT_GHIDRA_PORT, @mcp.tool() def set_comment(port: int = DEFAULT_GHIDRA_PORT, - address: str = "", + address: str = None, comment: str = "", comment_type: str = "plate") -> dict: """Set a comment at the specified address @@ -1366,7 +1656,10 @@ def set_comment(port: int = DEFAULT_GHIDRA_PORT, if not address: return { "success": False, - "error": "Address parameter is required", + "error": { + "code": "MISSING_PARAMETER", + "message": "Address parameter is required" + }, "timestamp": int(time.time() * 1000) } @@ -1378,6 +1671,48 @@ def set_comment(port: int = DEFAULT_GHIDRA_PORT, return simplify_response(response) +@mcp.tool() +def set_decompiler_comment(port: int = DEFAULT_GHIDRA_PORT, + address: str = None, + comment: str = "") -> dict: + """Set a decompiler comment at the specified address + + Args: + port: Ghidra instance port (default: 8192) + address: Memory address in hex format + comment: Comment text + + Returns: + dict: Operation result + """ + if not address: + return { + "success": False, + "error": { + "code": "MISSING_PARAMETER", + "message": "Address parameter is required" + }, + "timestamp": int(time.time() * 1000) + } + + # Decompiler comments are typically "plate" comments in Ghidra + payload = { + "comment": comment + } + + # First try to post to the more specific decompiler endpoint if it exists + try: + response = safe_post(port, f"functions/{address}/comments", payload) + if response.get("success", False): + return simplify_response(response) + except Exception as e: + # Fall back to the general memory comments endpoint + pass + + # Fall back to the normal comment mechanism with "plate" type + return set_comment(port, address, comment, "plate") + + def handle_sigint(signum, frame): os._exit(0) @@ -1397,6 +1732,41 @@ def periodic_discovery(): response = requests.get(f"{url}/plugin-version", timeout=1) if not response.ok: ports_to_remove.append(port) + continue + + # Update program info if available (especially to get project name) + try: + info_url = f"{url}/program" + info_response = requests.get(info_url, timeout=1) + if info_response.ok: + try: + info_data = info_response.json() + if "result" in info_data: + result = info_data["result"] + if isinstance(result, dict): + # Extract project and file from programId (format: "project:/file") + program_id = result.get("programId", "") + if ":" in program_id: + project_name, file_path = program_id.split(":", 1) + info["project"] = project_name + # Remove leading slash from file path if present + if file_path.startswith("/"): + file_path = file_path[1:] + info["path"] = file_path + + # Get file name directly from the result + info["file"] = result.get("name", "") + + # Get other metadata + info["language_id"] = result.get("languageId", "") + info["compiler_spec_id"] = result.get("compilerSpecId", "") + info["image_base"] = result.get("image_base", "") + except Exception as e: + print(f"Error parsing info endpoint during discovery: {e}", file=sys.stderr) + except Exception: + # Non-critical, continue even if update fails + pass + except requests.exceptions.RequestException: ports_to_remove.append(port) diff --git a/src/main/java/eu/starsong/ghidra/endpoints/AnalysisEndpoints.java b/src/main/java/eu/starsong/ghidra/endpoints/AnalysisEndpoints.java index 822d5d7..6f9ed05 100644 --- a/src/main/java/eu/starsong/ghidra/endpoints/AnalysisEndpoints.java +++ b/src/main/java/eu/starsong/ghidra/endpoints/AnalysisEndpoints.java @@ -33,6 +33,9 @@ public class AnalysisEndpoints extends AbstractEndpoint { @Override public void registerEndpoints(HttpServer server) { server.createContext("/analysis", this::handleAnalysisRequest); + + // NOTE: The callgraph endpoint is now registered in ProgramEndpoints + // This comment is to avoid confusion during future maintenance } private void handleAnalysisRequest(HttpExchange exchange) throws IOException { diff --git a/src/main/java/eu/starsong/ghidra/endpoints/FunctionEndpoints.java b/src/main/java/eu/starsong/ghidra/endpoints/FunctionEndpoints.java index aca29dc..073a942 100644 --- a/src/main/java/eu/starsong/ghidra/endpoints/FunctionEndpoints.java +++ b/src/main/java/eu/starsong/ghidra/endpoints/FunctionEndpoints.java @@ -406,9 +406,21 @@ public class FunctionEndpoints extends AbstractEndpoint { } if (signature != null && !signature.isEmpty()) { - // Update signature - placeholder - sendErrorResponse(exchange, 501, "Updating function signature not implemented", "NOT_IMPLEMENTED"); - return; + // Update function signature using our utility method + try { + boolean success = TransactionHelper.executeInTransaction(program, "Set Function Signature", () -> { + return GhidraUtil.setFunctionSignature(function, signature); + }); + + if (!success) { + sendErrorResponse(exchange, 400, "Failed to set function signature: invalid signature format", "SIGNATURE_FAILED"); + return; + } + changed = true; + } catch (Exception e) { + sendErrorResponse(exchange, 400, "Failed to set function signature: " + e.getMessage(), "SIGNATURE_FAILED"); + return; + } } if (comment != null) { @@ -830,9 +842,21 @@ public class FunctionEndpoints extends AbstractEndpoint { } if (signature != null && !signature.isEmpty()) { - // Update signature - sendErrorResponse(exchange, 501, "Updating function signature not implemented", "NOT_IMPLEMENTED"); - return; + // Update function signature using our utility method + try { + boolean success = TransactionHelper.executeInTransaction(program, "Set Function Signature", () -> { + return GhidraUtil.setFunctionSignature(function, signature); + }); + + if (!success) { + sendErrorResponse(exchange, 400, "Failed to set function signature: invalid signature format", "SIGNATURE_FAILED"); + return; + } + changed = true; + } catch (Exception e) { + sendErrorResponse(exchange, 400, "Failed to set function signature: " + e.getMessage(), "SIGNATURE_FAILED"); + return; + } } if (comment != null) { diff --git a/src/main/java/eu/starsong/ghidra/endpoints/MemoryEndpoints.java b/src/main/java/eu/starsong/ghidra/endpoints/MemoryEndpoints.java index a2cebb2..2c61ae6 100644 --- a/src/main/java/eu/starsong/ghidra/endpoints/MemoryEndpoints.java +++ b/src/main/java/eu/starsong/ghidra/endpoints/MemoryEndpoints.java @@ -5,11 +5,13 @@ import com.google.gson.JsonObject; import com.sun.net.httpserver.HttpExchange; import com.sun.net.httpserver.HttpServer; import eu.starsong.ghidra.api.ResponseBuilder; +import eu.starsong.ghidra.util.TransactionHelper; import ghidra.program.model.address.Address; import ghidra.program.model.address.AddressFactory; import ghidra.program.model.mem.Memory; import ghidra.program.model.mem.MemoryAccessException; import ghidra.program.model.mem.MemoryBlock; +import ghidra.program.model.listing.CodeUnit; import ghidra.program.model.listing.Program; import ghidra.framework.plugintool.PluginTool; import ghidra.util.Msg; @@ -39,8 +41,25 @@ public class MemoryEndpoints extends AbstractEndpoint { @Override public void registerEndpoints(HttpServer server) { + // Per HttpServer docs: paths are matched by longest matching prefix + // So register specific endpoints first, then more general ones + + // Comments endpoint path needs to be registered with a specific context path + // Example: /memory/0x1000/comments/plate needs a specific handler + server.createContext("/memory/", exchange -> { + String path = exchange.getRequestURI().getPath(); + if (path.contains("/comments/")) { + handleMemoryAddressRequest(exchange); + } else if (path.equals("/memory/blocks")) { + handleMemoryBlocksRequest(exchange); + } else { + // Handle as general memory address request + handleMemoryAddressRequest(exchange); + } + }); + + // Register the most general endpoint last server.createContext("/memory", this::handleMemoryRequest); - server.createContext("/memory/blocks", this::handleMemoryBlocksRequest); } private void handleMemoryRequest(HttpExchange exchange) throws IOException { @@ -170,7 +189,190 @@ public class MemoryEndpoints extends AbstractEndpoint { } } - private void handleMemoryBlocksRequest(HttpExchange exchange) throws IOException { + /** + * Handle requests to /memory/{address} including child resources like comments + */ +private void handleMemoryAddressRequest(HttpExchange exchange) throws IOException { + try { + // Extract address from path: /memory/{address}/... + String path = exchange.getRequestURI().getPath(); + if (path.equals("/memory/") || path.equals("/memory")) { + handleMemoryRequest(exchange); + return; + } + + // Parse address from path + String remainingPath = path.substring("/memory/".length()); + + // Check if this is a request for a specific address's comments + if (remainingPath.contains("/comments/")) { + // Format: /memory/{address}/comments/{comment_type} + String[] parts = remainingPath.split("/comments/", 2); + String addressStr = parts[0]; + String commentType = parts.length > 1 ? parts[1] : "plate"; // Default to plate comments + + handleMemoryComments(exchange, addressStr, commentType); + return; + } + + // Otherwise, treat as a direct memory request with address in the path + String addressStr = remainingPath; + Map params = parseQueryParams(exchange); + + // Handle same as the query parameter version + params.put("address", addressStr); + exchange.setAttribute("address", addressStr); + + // Delegate to the main memory handler + handleMemoryRequest(exchange); + } catch (Exception e) { + Msg.error(this, "Error handling memory address endpoint", e); + sendErrorResponse(exchange, 500, "Internal server error: " + e.getMessage(), "INTERNAL_ERROR"); + } +} + +/** + * Handle requests to set or get comments at a specific memory address + */ +private void handleMemoryComments(HttpExchange exchange, String addressStr, String commentType) throws IOException { + try { + String method = exchange.getRequestMethod(); + Program program = getCurrentProgram(); + + if (program == null) { + sendErrorResponse(exchange, 400, "No program loaded", "NO_PROGRAM_LOADED"); + return; + } + + // Parse address + AddressFactory addressFactory = program.getAddressFactory(); + Address address; + try { + address = addressFactory.getAddress(addressStr); + } catch (Exception e) { + sendErrorResponse(exchange, 400, "Invalid address format: " + addressStr, "INVALID_ADDRESS"); + return; + } + + // Validate comment type + if (!isValidCommentType(commentType)) { + sendErrorResponse(exchange, 400, "Invalid comment type: " + commentType, "INVALID_COMMENT_TYPE"); + return; + } + + if ("GET".equals(method)) { + // Get existing comment + String comment = getCommentByType(program, address, commentType); + + Map result = new HashMap<>(); + result.put("address", addressStr); + result.put("comment_type", commentType); + result.put("comment", comment != null ? comment : ""); + + ResponseBuilder builder = new ResponseBuilder(exchange, port) + .success(true) + .result(result) + .addLink("self", "/memory/" + addressStr + "/comments/" + commentType); + + sendJsonResponse(exchange, builder.build(), 200); + + } else if ("POST".equals(method)) { + // Set comment + Map params = parseJsonPostParams(exchange); + String comment = params.get("comment"); + + if (comment == null) { + sendErrorResponse(exchange, 400, "Comment parameter is required", "MISSING_PARAMETER"); + return; + } + + boolean success = setCommentByType(program, address, commentType, comment); + + if (success) { + Map result = new HashMap<>(); + result.put("address", addressStr); + result.put("comment_type", commentType); + result.put("comment", comment); + + ResponseBuilder builder = new ResponseBuilder(exchange, port) + .success(true) + .result(result) + .addLink("self", "/memory/" + addressStr + "/comments/" + commentType); + + sendJsonResponse(exchange, builder.build(), 200); + } else { + sendErrorResponse(exchange, 500, "Failed to set comment", "COMMENT_SET_FAILED"); + } + } else { + sendErrorResponse(exchange, 405, "Method Not Allowed", "METHOD_NOT_ALLOWED"); + } + } catch (Exception e) { + Msg.error(this, "Error handling memory comments", e); + sendErrorResponse(exchange, 500, "Internal server error: " + e.getMessage(), "INTERNAL_ERROR"); + } +} + +/** + * Check if the comment type is valid + */ +private boolean isValidCommentType(String commentType) { + return commentType.equals("plate") || + commentType.equals("pre") || + commentType.equals("post") || + commentType.equals("eol") || + commentType.equals("repeatable"); +} + +/** + * Get a comment by type at the specified address + */ +private String getCommentByType(Program program, Address address, String commentType) { + if (program == null) return null; + + int type = getCommentTypeInt(commentType); + return program.getListing().getComment(type, address); +} + +/** + * Set a comment by type at the specified address + */ +private boolean setCommentByType(Program program, Address address, String commentType, String comment) { + if (program == null) return false; + + int type = getCommentTypeInt(commentType); + + try { + return TransactionHelper.executeInTransaction(program, "Set Comment", () -> { + program.getListing().setComment(address, type, comment); + return true; + }); + } catch (Exception e) { + Msg.error(this, "Error setting comment", e); + return false; + } +} + +/** + * Convert comment type string to Ghidra's internal comment type constants + */ +private int getCommentTypeInt(String commentType) { + switch (commentType.toLowerCase()) { + case "plate": + return CodeUnit.PLATE_COMMENT; + case "pre": + return CodeUnit.PRE_COMMENT; + case "post": + return CodeUnit.POST_COMMENT; + case "eol": + return CodeUnit.EOL_COMMENT; + case "repeatable": + return CodeUnit.REPEATABLE_COMMENT; + default: + return CodeUnit.PLATE_COMMENT; + } +} + +private void handleMemoryBlocksRequest(HttpExchange exchange) throws IOException { try { if ("GET".equals(exchange.getRequestMethod())) { Map qparams = parseQueryParams(exchange); diff --git a/src/main/java/eu/starsong/ghidra/endpoints/ProgramEndpoints.java b/src/main/java/eu/starsong/ghidra/endpoints/ProgramEndpoints.java index ec00246..c963f94 100644 --- a/src/main/java/eu/starsong/ghidra/endpoints/ProgramEndpoints.java +++ b/src/main/java/eu/starsong/ghidra/endpoints/ProgramEndpoints.java @@ -50,6 +50,8 @@ public class ProgramEndpoints extends AbstractEndpoint { server.createContext("/address", this::handleCurrentAddress); server.createContext("/function", this::handleCurrentFunction); + // Register direct analysis endpoints according to HATEOAS API + server.createContext("/analysis/callgraph", this::handleCallGraph); } @Override @@ -1283,8 +1285,8 @@ public class ProgramEndpoints extends AbstractEndpoint { sendErrorResponse(exchange, 405, "Method Not Allowed", "METHOD_NOT_ALLOWED"); } } else if (path.equals("/callgraph") || path.startsWith("/callgraph/")) { - // Handle call graph generation - handleCallGraph(exchange, program, path); + // Handle call graph generation - for backward compatibility + handleCallGraph(exchange); } else if (path.equals("/dataflow") || path.startsWith("/dataflow/")) { // Handle data flow analysis handleDataFlow(exchange, program, path); @@ -1460,7 +1462,14 @@ public class ProgramEndpoints extends AbstractEndpoint { /** * Handle call graph generation */ - private void handleCallGraph(HttpExchange exchange, Program program, String path) throws IOException { + private void handleCallGraph(HttpExchange exchange) throws IOException { + Program program = getCurrentProgram(); + if (program == null) { + sendErrorResponse(exchange, 404, "No program is currently open", "NO_PROGRAM_OPEN"); + return; + } + + String path = ""; if (!"GET".equals(exchange.getRequestMethod())) { sendErrorResponse(exchange, 405, "Method Not Allowed", "METHOD_NOT_ALLOWED"); return; diff --git a/src/main/java/eu/starsong/ghidra/util/GhidraUtil.java b/src/main/java/eu/starsong/ghidra/util/GhidraUtil.java index deca11a..0bdcaba 100644 --- a/src/main/java/eu/starsong/ghidra/util/GhidraUtil.java +++ b/src/main/java/eu/starsong/ghidra/util/GhidraUtil.java @@ -13,8 +13,10 @@ import ghidra.program.model.data.DataTypeManager; import ghidra.program.model.listing.Function; import ghidra.program.model.listing.FunctionManager; import ghidra.program.model.listing.Parameter; +import ghidra.program.model.listing.ParameterImpl; import ghidra.program.model.listing.Program; import ghidra.program.model.listing.Variable; +import ghidra.program.model.symbol.SourceType; import ghidra.program.model.pcode.HighFunction; import ghidra.program.model.pcode.HighVariable; import ghidra.program.model.pcode.PcodeOp; @@ -397,4 +399,87 @@ public class GhidraUtil { return variables; } + + /** + * Applies a function signature to an existing function. + * @param function The function to update + * @param signatureStr The C-style function signature string + * @return true if successful, false otherwise + */ + public static boolean setFunctionSignature(Function function, String signatureStr) { + if (function == null || signatureStr == null || signatureStr.isEmpty()) { + return false; + } + + Program program = function.getProgram(); + if (program == null) { + return false; + } + + try { + // Create a function signature parser + ghidra.app.util.parser.FunctionSignatureParser parser = + new ghidra.app.util.parser.FunctionSignatureParser( + program.getDataTypeManager(), null); + + // Parse the signature string + ghidra.program.model.data.FunctionDefinitionDataType functionDef = + parser.parse(function.getSignature(), signatureStr); + + if (functionDef == null) { + return false; + } + + // Get source type for update + ghidra.program.model.symbol.SourceType sourceType = + ghidra.program.model.symbol.SourceType.USER_DEFINED; + + // Get the parameters from the function definition + ghidra.program.model.data.ParameterDefinition[] paramDefs = + functionDef.getArguments(); + + try { + // Get return type from the function definition + ghidra.program.model.data.DataType returnType = functionDef.getReturnType(); + + // Set the return type + function.setReturnType(returnType, sourceType); + + // Get calling convention if available + if (functionDef.getCallingConvention() != null) { + String callingConvention = functionDef.getCallingConvention().getName(); + function.setCallingConvention(callingConvention); + } + + // Remove all existing parameters + while (function.getParameterCount() > 0) { + function.removeParameter(0); + } + + // Add each parameter + if (paramDefs != null) { + for (int i = 0; i < paramDefs.length; i++) { + ghidra.program.model.data.ParameterDefinition paramDef = paramDefs[i]; + String name = paramDef.getName(); + ghidra.program.model.data.DataType dataType = paramDef.getDataType(); + + // Create parameter and then add it + Parameter param = new ParameterImpl(name, dataType, program); + function.addParameter(param, sourceType); + } + } + + return true; + } catch (ghidra.util.exception.InvalidInputException e) { + ghidra.util.Msg.error(GhidraUtil.class, + "Error setting function parameters: " + e.getMessage(), e); + return false; + } + } + catch (Exception e) { + ghidra.util.Msg.error(GhidraUtil.class, + "Error setting function signature: " + e.getMessage(), e); + return false; + } + } }