From ba7781643fa5a849902fd1eec9e5cc7d79c2e0b7 Mon Sep 17 00:00:00 2001 From: Teal Bauer Date: Tue, 8 Apr 2025 22:57:57 +0200 Subject: [PATCH] chore: Completed conversion of bridge/plugin protocol to pure JSON --- CHANGELOG.md | 26 +- bridge_mcp_hydra.py | 108 ++--- .../eu/starsong/ghidra/GhydraMCPPlugin.java | 431 +++++++++++------- test_http_api.py | 25 + test_mcp_client.py | 18 + 5 files changed, 381 insertions(+), 227 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 22cb97e..29fb14a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,16 +11,34 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). - Cleaned up comments and simplified code in bridge_mcp_hydra.py - Improved error handling and response formatting - Standardized API response structure across all endpoints +- Completed conversion of bridge/plugin protocol to pure JSON: + - All endpoints now use structured JSON requests/responses + - Removed all string parsing/formatting code from both bridge and plugin + - Standardized error handling with consistent JSON error responses + - Added detailed JSON schemas for all API endpoints + - Using only POST methods for mutation endpoints (previously mixed PUT/POST) + - Uniform camelCase parameter naming across JSON payloads + - Improved response metadata (timestamps, status codes) + +### Changed +- Completed conversion of bridge/plugin protocol to pure JSON: + - All endpoints now use structured JSON requests/responses + - Removed all string parsing/formatting code from both bridge and plugin + - Standardized error handling with consistent JSON error responses + - Added detailed JSON schemas for all API endpoints + - Using only POST methods for mutation endpoints (previously mixed PUT/POST) + - Uniform camelCase parameter naming across JSON payloads + - Improved response metadata (timestamps, status codes) ### Added - Added GHIDRA_HTTP_API.md with documentation of the Java Plugin's HTTP API - Added better docstrings and type hints for all MCP tools - Added improved content-type handling for API requests - Added decompiler output controls to customize analysis results: - - Choose between clean C-like pseudocode (default) or raw decompiler output - - Toggle syntax tree visibility for detailed analysis - - Select different simplification styles for alternate views - - Useful for comparing different decompilation approaches or focusing on specific aspects of the code +- Choose between clean C-like pseudocode (default) or raw decompiler output +- Toggle syntax tree visibility for detailed analysis +- Select different simplification styles for alternate views +- Useful for comparing different decompilation approaches or focusing on specific aspects of the code Example showing how to get raw decompiler output with syntax tree: ```xml diff --git a/bridge_mcp_hydra.py b/bridge_mcp_hydra.py index 1daf174..5367493 100644 --- a/bridge_mcp_hydra.py +++ b/bridge_mcp_hydra.py @@ -358,20 +358,11 @@ def get_function(port: int = DEFAULT_GHIDRA_PORT, name: str = "", cCode: bool = Returns: dict: Contains function name, address, signature and decompilation """ - response = safe_get(port, f"functions/{quote(name)}", { + return safe_get(port, f"functions/{quote(name)}", { "cCode": str(cCode).lower(), "syntaxTree": str(syntaxTree).lower(), "simplificationStyle": simplificationStyle }) - - if not isinstance(response, dict) or "success" not in response: - return { - "success": False, - "error": "Invalid response format from Ghidra plugin", - "timestamp": int(time.time() * 1000), - "port": port - } - return response @mcp.tool() def update_function(port: int = DEFAULT_GHIDRA_PORT, name: str = "", new_name: str = "") -> str: @@ -499,51 +490,55 @@ def get_function_by_address(port: int = DEFAULT_GHIDRA_PORT, address: str = "") Returns: dict: Contains function name, address, signature and decompilation """ - response = safe_get(port, "get_function_by_address", {"address": address}) - - if isinstance(response, dict) and "success" in response: - return response - elif isinstance(response, str): - return { - "success": True, - "result": { - "decompilation": response, - "address": address - }, - "timestamp": int(time.time() * 1000), - "port": port - } - else: - return { - "success": False, - "error": "Unexpected response format from Ghidra plugin", - "timestamp": int(time.time() * 1000), - "port": port - } + return safe_get(port, "get_function_by_address", {"address": address}) @mcp.tool() def get_current_address(port: int = DEFAULT_GHIDRA_PORT) -> dict: - """Get currently selected address in Ghidra UI - + """Get the address currently selected in Ghidra's UI + Args: port: Ghidra instance port (default: 8192) - + Returns: - dict: Contains current memory address in hex format + Dict containing: + - success: boolean indicating success + - result: object with address field + - error: error message if failed + - timestamp: timestamp of response """ - return safe_get(port, "get_current_address") + response = safe_get(port, "get_current_address") + if isinstance(response, dict) and "success" in response: + return response + return { + "success": False, + "error": "Unexpected response format from Ghidra plugin", + "timestamp": int(time.time() * 1000), + "port": port + } @mcp.tool() def get_current_function(port: int = DEFAULT_GHIDRA_PORT) -> dict: - """Get currently selected function in Ghidra UI - + """Get the function currently selected in Ghidra's UI + Args: port: Ghidra instance port (default: 8192) Returns: - dict: Contains function name, address and signature + Dict containing: + - success: boolean indicating success + - result: object with name, address and signature fields + - error: error message if failed + - timestamp: timestamp of response """ - return safe_get(port, "get_current_function") + response = safe_get(port, "get_current_function") + if isinstance(response, dict) and "success" in response: + return response + return { + "success": False, + "error": "Unexpected response format from Ghidra plugin", + "timestamp": int(time.time() * 1000), + "port": port + } @mcp.tool() def decompile_function_by_address(port: int = DEFAULT_GHIDRA_PORT, address: str = "", cCode: bool = True, syntaxTree: bool = False, simplificationStyle: str = "normalize") -> dict: @@ -559,21 +554,12 @@ def decompile_function_by_address(port: int = DEFAULT_GHIDRA_PORT, address: str Returns: dict: Contains decompiled code in 'result.decompilation' """ - response = safe_get(port, "decompile_function", { + return safe_get(port, "decompile_function", { "address": address, "cCode": str(cCode).lower(), "syntaxTree": str(syntaxTree).lower(), "simplificationStyle": simplificationStyle }) - - if not isinstance(response, dict) or "success" not in response: - return { - "success": False, - "error": "Invalid response format from Ghidra plugin", - "timestamp": int(time.time() * 1000), - "port": port - } - return response @mcp.tool() def disassemble_function(port: int = DEFAULT_GHIDRA_PORT, address: str = "") -> dict: @@ -691,16 +677,7 @@ def list_variables(port: int = DEFAULT_GHIDRA_PORT, offset: int = 0, limit: int if search: params["search"] = search - response = safe_get(port, "variables", params) - - if not isinstance(response, dict) or "success" not in response: - return { - "success": False, - "error": "Invalid response format from Ghidra plugin", - "timestamp": int(time.time() * 1000), - "port": port - } - return response + return safe_get(port, "variables", params) @mcp.tool() def list_function_variables(port: int = DEFAULT_GHIDRA_PORT, function: str = "") -> dict: @@ -717,16 +694,7 @@ def list_function_variables(port: int = DEFAULT_GHIDRA_PORT, function: str = "") return {"success": False, "error": "Function name is required"} encoded_name = quote(function) - response = safe_get(port, f"functions/{encoded_name}/variables", {}) - - if not isinstance(response, dict) or "success" not in response: - return { - "success": False, - "error": "Invalid response format from Ghidra plugin", - "timestamp": int(time.time() * 1000), - "port": port - } - return response + return safe_get(port, f"functions/{encoded_name}/variables", {}) @mcp.tool() def rename_variable(port: int = DEFAULT_GHIDRA_PORT, function: str = "", name: str = "", new_name: str = "") -> dict: diff --git a/src/main/java/eu/starsong/ghidra/GhydraMCPPlugin.java b/src/main/java/eu/starsong/ghidra/GhydraMCPPlugin.java index b783a40..e149e0b 100644 --- a/src/main/java/eu/starsong/ghidra/GhydraMCPPlugin.java +++ b/src/main/java/eu/starsong/ghidra/GhydraMCPPlugin.java @@ -116,19 +116,24 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { // Each listing endpoint uses offset & limit from query params: // Function resources server.createContext("/functions", exchange -> { - if ("GET".equals(exchange.getRequestMethod())) { - Map qparams = parseQueryParams(exchange); - int offset = parseIntOrDefault(qparams.get("offset"), 0); - int limit = parseIntOrDefault(qparams.get("limit"), 100); - String query = qparams.get("query"); - - if (query != null && !query.isEmpty()) { - sendResponse(exchange, searchFunctionsByName(query, offset, limit)); + try { + if ("GET".equals(exchange.getRequestMethod())) { + Map qparams = parseQueryParams(exchange); + int offset = parseIntOrDefault(qparams.get("offset"), 0); + int limit = parseIntOrDefault(qparams.get("limit"), 100); + String query = qparams.get("query"); + + if (query != null && !query.isEmpty()) { + sendJsonResponse(exchange, searchFunctionsByName(query, offset, limit)); + } else { + sendJsonResponse(exchange, getAllFunctionNames(offset, limit)); + } } else { - sendResponse(exchange, getAllFunctionNames(offset, limit)); + sendErrorResponse(exchange, 405, "Method Not Allowed"); } - } else { - exchange.sendResponseHeaders(405, -1); // Method Not Allowed + } catch (Exception e) { + Msg.error(this, "Error in /functions endpoint", e); + sendErrorResponse(exchange, 500, "Internal server error"); } }); @@ -254,116 +259,131 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { // Class resources server.createContext("/classes", exchange -> { - if ("GET".equals(exchange.getRequestMethod())) { - try { + try { + if ("GET".equals(exchange.getRequestMethod())) { Map qparams = parseQueryParams(exchange); int offset = parseIntOrDefault(qparams.get("offset"), 0); int limit = parseIntOrDefault(qparams.get("limit"), 100); sendJsonResponse(exchange, getAllClassNames(offset, limit)); - } catch (Exception e) { - Msg.error(this, "/classes: Error in request processing", e); - try { - sendErrorResponse(exchange, 500, "Internal server error"); - } catch (IOException ioe) { - Msg.error(this, "/classes: Failed to send error response", ioe); - } + } else { + sendErrorResponse(exchange, 405, "Method Not Allowed"); } - } else { - exchange.sendResponseHeaders(405, -1); // Method Not Allowed + } catch (Exception e) { + Msg.error(this, "Error in /classes endpoint", e); + sendErrorResponse(exchange, 500, "Internal server error"); } }); // Memory segments server.createContext("/segments", exchange -> { - if ("GET".equals(exchange.getRequestMethod())) { - Map qparams = parseQueryParams(exchange); - int offset = parseIntOrDefault(qparams.get("offset"), 0); - int limit = parseIntOrDefault(qparams.get("limit"), 100); - sendResponse(exchange, listSegments(offset, limit)); - } else { - exchange.sendResponseHeaders(405, -1); // Method Not Allowed + try { + if ("GET".equals(exchange.getRequestMethod())) { + Map qparams = parseQueryParams(exchange); + int offset = parseIntOrDefault(qparams.get("offset"), 0); + int limit = parseIntOrDefault(qparams.get("limit"), 100); + sendJsonResponse(exchange, listSegments(offset, limit)); + } else { + sendErrorResponse(exchange, 405, "Method Not Allowed"); + } + } catch (Exception e) { + Msg.error(this, "Error in /segments endpoint", e); + sendErrorResponse(exchange, 500, "Internal server error"); } }); // Symbol resources (imports/exports) server.createContext("/symbols/imports", exchange -> { - if ("GET".equals(exchange.getRequestMethod())) { - Map qparams = parseQueryParams(exchange); - int offset = parseIntOrDefault(qparams.get("offset"), 0); - int limit = parseIntOrDefault(qparams.get("limit"), 100); - sendResponse(exchange, listImports(offset, limit)); - } else { - exchange.sendResponseHeaders(405, -1); // Method Not Allowed + try { + if ("GET".equals(exchange.getRequestMethod())) { + Map qparams = parseQueryParams(exchange); + int offset = parseIntOrDefault(qparams.get("offset"), 0); + int limit = parseIntOrDefault(qparams.get("limit"), 100); + sendJsonResponse(exchange, listImports(offset, limit)); + } else { + sendErrorResponse(exchange, 405, "Method Not Allowed"); + } + } catch (Exception e) { + Msg.error(this, "Error in /symbols/imports endpoint", e); + sendErrorResponse(exchange, 500, "Internal server error"); } }); server.createContext("/symbols/exports", exchange -> { - if ("GET".equals(exchange.getRequestMethod())) { - Map qparams = parseQueryParams(exchange); - int offset = parseIntOrDefault(qparams.get("offset"), 0); - int limit = parseIntOrDefault(qparams.get("limit"), 100); - sendResponse(exchange, listExports(offset, limit)); - } else { - exchange.sendResponseHeaders(405, -1); // Method Not Allowed + try { + if ("GET".equals(exchange.getRequestMethod())) { + Map qparams = parseQueryParams(exchange); + int offset = parseIntOrDefault(qparams.get("offset"), 0); + int limit = parseIntOrDefault(qparams.get("limit"), 100); + sendJsonResponse(exchange, listExports(offset, limit)); + } else { + sendErrorResponse(exchange, 405, "Method Not Allowed"); + } + } catch (Exception e) { + Msg.error(this, "Error in /symbols/exports endpoint", e); + sendErrorResponse(exchange, 500, "Internal server error"); } }); // Namespace resources server.createContext("/namespaces", exchange -> { - if ("GET".equals(exchange.getRequestMethod())) { - Map qparams = parseQueryParams(exchange); - int offset = parseIntOrDefault(qparams.get("offset"), 0); - int limit = parseIntOrDefault(qparams.get("limit"), 100); - sendResponse(exchange, listNamespaces(offset, limit)); - } else { - exchange.sendResponseHeaders(405, -1); // Method Not Allowed + try { + if ("GET".equals(exchange.getRequestMethod())) { + Map qparams = parseQueryParams(exchange); + int offset = parseIntOrDefault(qparams.get("offset"), 0); + int limit = parseIntOrDefault(qparams.get("limit"), 100); + sendJsonResponse(exchange, listNamespaces(offset, limit)); + } else { + sendErrorResponse(exchange, 405, "Method Not Allowed"); + } + } catch (Exception e) { + Msg.error(this, "Error in /namespaces endpoint", e); + sendErrorResponse(exchange, 500, "Internal server error"); } }); // Data resources server.createContext("/data", exchange -> { - if ("GET".equals(exchange.getRequestMethod())) { - Map qparams = parseQueryParams(exchange); - int offset = parseIntOrDefault(qparams.get("offset"), 0); - int limit = parseIntOrDefault(qparams.get("limit"), 100); - sendResponse(exchange, listDefinedData(offset, limit)); - } else if ("POST".equals(exchange.getRequestMethod())) { // Change PUT to POST - Map params = parseJsonPostParams(exchange); // Use specific JSON parser - boolean success = renameDataAtAddress(params.get("address"), params.get("newName")); // Expect camelCase + try { + if ("GET".equals(exchange.getRequestMethod())) { + Map qparams = parseQueryParams(exchange); + int offset = parseIntOrDefault(qparams.get("offset"), 0); + int limit = parseIntOrDefault(qparams.get("limit"), 100); + sendJsonResponse(exchange, listDefinedData(offset, limit)); + } else if ("POST".equals(exchange.getRequestMethod())) { + Map params = parseJsonPostParams(exchange); + boolean success = renameDataAtAddress(params.get("address"), params.get("newName")); + JsonObject response = new JsonObject(); response.addProperty("success", success); response.addProperty("message", success ? "Data renamed successfully" : "Failed to rename data"); response.addProperty("timestamp", System.currentTimeMillis()); response.addProperty("port", this.port); - - Gson gson = new Gson(); - String json = gson.toJson(response); - byte[] bytes = json.getBytes(StandardCharsets.UTF_8); - - exchange.getResponseHeaders().set("Content-Type", "application/json; charset=utf-8"); - exchange.getResponseHeaders().set("Content-Length", String.valueOf(bytes.length)); - exchange.sendResponseHeaders(success ? 200 : 400, bytes.length); - - try (OutputStream os = exchange.getResponseBody()) { - os.write(bytes); - os.flush(); - } - } else { - exchange.sendResponseHeaders(405, -1); // Method Not Allowed + sendJsonResponse(exchange, response); + } else { + sendErrorResponse(exchange, 405, "Method Not Allowed"); + } + } catch (Exception e) { + Msg.error(this, "Error in /data endpoint", e); + sendErrorResponse(exchange, 500, "Internal server error"); } }); // Global variables endpoint server.createContext("/variables", exchange -> { - if ("GET".equals(exchange.getRequestMethod())) { - Map qparams = parseQueryParams(exchange); - int offset = parseIntOrDefault(qparams.get("offset"), 0); - int limit = parseIntOrDefault(qparams.get("limit"), 100); - String search = qparams.get("search"); - - sendResponse(exchange, listVariables(offset, limit, search)); - } else { - exchange.sendResponseHeaders(405, -1); // Method Not Allowed + try { + if ("GET".equals(exchange.getRequestMethod())) { + Map qparams = parseQueryParams(exchange); + int offset = parseIntOrDefault(qparams.get("offset"), 0); + int limit = parseIntOrDefault(qparams.get("limit"), 100); + String search = qparams.get("search"); + + sendJsonResponse(exchange, listVariables(offset, limit, search)); + } else { + sendErrorResponse(exchange, 405, "Method Not Allowed"); + } + } catch (Exception e) { + Msg.error(this, "Error in /variables endpoint", e); + sendErrorResponse(exchange, 500, "Internal server error"); } }); @@ -724,54 +744,110 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { } }); + // Add get current address endpoint (Changed to GET to match test expectations) + server.createContext("/get_current_address", exchange -> { + if ("GET".equals(exchange.getRequestMethod())) { + Program program = getCurrentProgram(); + if (program == null) { + sendErrorResponse(exchange, 400, "No program loaded"); + return; + } - // Super simple info endpoint with guaranteed response + JsonObject response = new JsonObject(); + JsonObject resultObj = new JsonObject(); + + try { + Address currentAddr = getCurrentAddress(); + if (currentAddr != null) { + resultObj.addProperty("address", currentAddr.toString()); + response.addProperty("success", true); + } else { + resultObj.addProperty("address", ""); + response.addProperty("success", false); + response.addProperty("message", "No address currently selected"); + } + } catch (Exception e) { + Msg.error(this, "Error getting current address", e); + response.addProperty("success", false); + response.addProperty("error", "Error getting current address: " + e.getMessage()); + } + + response.add("result", resultObj); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + sendJsonResponse(exchange, response); + } else { + exchange.sendResponseHeaders(405, -1); // Method Not Allowed + } + }); + + // Add get current function endpoint (Changed to GET to match test expectations) + server.createContext("/get_current_function", exchange -> { + if ("GET".equals(exchange.getRequestMethod())) { + Program program = getCurrentProgram(); + if (program == null) { + sendErrorResponse(exchange, 400, "No program loaded"); + return; + } + + JsonObject response = new JsonObject(); + JsonObject resultObj = new JsonObject(); + + try { + Function currentFunc = getCurrentFunction(); + if (currentFunc != null) { + resultObj.addProperty("name", currentFunc.getName()); + resultObj.addProperty("address", currentFunc.getEntryPoint().toString()); + resultObj.addProperty("signature", currentFunc.getSignature().getPrototypeString()); + response.addProperty("success", true); + } else { + resultObj.addProperty("name", ""); + resultObj.addProperty("address", ""); + resultObj.addProperty("signature", ""); + response.addProperty("success", false); + response.addProperty("message", "No function currently selected"); + } + } catch (Exception e) { + Msg.error(this, "Error getting current function", e); + response.addProperty("success", false); + response.addProperty("error", "Error getting current function: " + e.getMessage()); + } + + response.add("result", resultObj); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + sendJsonResponse(exchange, response); + } else { + exchange.sendResponseHeaders(405, -1); // Method Not Allowed + } + }); + + + // Info endpoint with standardized JSON response server.createContext("/info", exchange -> { try { - String response = "{\n"; - response += "\"port\": " + port + ",\n"; - response += "\"isBaseInstance\": " + isBaseInstance + ",\n"; + JsonObject response = new JsonObject(); + response.addProperty("port", port); + response.addProperty("isBaseInstance", isBaseInstance); // Try to get program info if available Program program = getCurrentProgram(); - String programName = "\"\""; - if (program != null) { - programName = "\"" + program.getName() + "\""; - } + response.addProperty("file", program != null ? program.getName() : ""); // Try to get project info if available Project project = tool.getProject(); - String projectName = "\"\""; - if (project != null) { - projectName = "\"" + project.getName() + "\""; - } + response.addProperty("project", project != null ? project.getName() : ""); - response += "\"project\": " + projectName + ",\n"; - response += "\"file\": " + programName + "\n"; - response += "}"; + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("success", true); - Msg.info(this, "Sending /info response: " + response); - byte[] bytes = response.getBytes(StandardCharsets.UTF_8); - exchange.getResponseHeaders().set("Content-Type", "application/json; charset=utf-8"); - exchange.sendResponseHeaders(200, bytes.length); - try (OutputStream os = exchange.getResponseBody()) { - os.write(bytes); - } + sendJsonResponse(exchange, response); } catch (Exception e) { Msg.error(this, "Error serving /info endpoint", e); - try { - String error = "{\"error\": \"Internal error\", \"port\": " + port + "}"; - byte[] bytes = error.getBytes(StandardCharsets.UTF_8); - exchange.getResponseHeaders().set("Content-Type", "application/json; charset=utf-8"); - // For mutation operations, set Content-Length explicitly to avoid chunked encoding - exchange.getResponseHeaders().set("Content-Length", String.valueOf(bytes.length)); - exchange.sendResponseHeaders(200, bytes.length); - OutputStream os = exchange.getResponseBody(); - os.write(bytes); - os.close(); - } catch (IOException ioe) { - Msg.error(this, "Failed to send error response", ioe); - } + JsonObject error = new JsonObject(); + error.addProperty("error", "Internal server error"); + error.addProperty("port", port); + sendJsonResponse(exchange, error, 500); } }); @@ -786,35 +862,19 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { } try { - String response = "{\n"; - response += "\"port\": " + port + ",\n"; - response += "\"isBaseInstance\": " + isBaseInstance + ",\n"; + JsonObject response = new JsonObject(); + response.addProperty("port", port); + response.addProperty("isBaseInstance", isBaseInstance); // Try to get program info if available Program program = getCurrentProgram(); - String programName = "\"\""; - if (program != null) { - programName = "\"" + program.getName() + "\""; - } + response.addProperty("file", program != null ? program.getName() : ""); // Try to get project info if available Project project = tool.getProject(); - String projectName = "\"\""; - if (project != null) { - projectName = "\"" + project.getName() + "\""; - } + response.addProperty("project", project != null ? project.getName() : ""); - response += "\"project\": " + projectName + ",\n"; - response += "\"file\": " + programName + "\n"; - response += "}"; - - Msg.info(this, "Sending / response: " + response); - byte[] bytes = response.getBytes(StandardCharsets.UTF_8); - exchange.getResponseHeaders().set("Content-Type", "application/json; charset=utf-8"); - exchange.sendResponseHeaders(200, bytes.length); - try (OutputStream os = exchange.getResponseBody()) { - os.write(bytes); - } + sendJsonResponse(exchange, response); } catch (Exception e) { Msg.error(this, "Error serving / endpoint", e); try { @@ -832,23 +892,55 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { }); server.createContext("/registerInstance", exchange -> { - Map params = parseJsonPostParams(exchange); // Use JSON parser - int port = parseIntOrDefault(params.get("port"), 0); - if (port > 0) { - sendResponse(exchange, "Instance registered on port " + port); - } else { - sendResponse(exchange, "Invalid port number"); + try { + Map params = parseJsonPostParams(exchange); + int port = parseIntOrDefault(params.get("port"), 0); + if (port > 0) { + JsonObject response = new JsonObject(); + response.addProperty("success", true); + response.addProperty("message", "Instance registered on port " + port); + response.addProperty("port", port); + response.addProperty("timestamp", System.currentTimeMillis()); + sendJsonResponse(exchange, response); + } else { + JsonObject error = new JsonObject(); + error.addProperty("error", "Invalid port number"); + error.addProperty("port", this.port); + sendJsonResponse(exchange, error, 400); + } + } catch (Exception e) { + Msg.error(this, "Error in /registerInstance", e); + JsonObject error = new JsonObject(); + error.addProperty("error", "Internal server error"); + error.addProperty("port", this.port); + sendJsonResponse(exchange, error, 500); } }); server.createContext("/unregisterInstance", exchange -> { - Map params = parseJsonPostParams(exchange); // Use JSON parser - int port = parseIntOrDefault(params.get("port"), 0); - if (port > 0 && activeInstances.containsKey(port)) { - activeInstances.remove(port); - sendResponse(exchange, "Unregistered instance on port " + port); - } else { - sendResponse(exchange, "No instance found on port " + port); + try { + Map params = parseJsonPostParams(exchange); + int port = parseIntOrDefault(params.get("port"), 0); + if (port > 0 && activeInstances.containsKey(port)) { + activeInstances.remove(port); + JsonObject response = new JsonObject(); + response.addProperty("success", true); + response.addProperty("message", "Unregistered instance on port " + port); + response.addProperty("port", port); + response.addProperty("timestamp", System.currentTimeMillis()); + sendJsonResponse(exchange, response); + } else { + JsonObject error = new JsonObject(); + error.addProperty("error", "No instance found on port " + port); + error.addProperty("port", this.port); + sendJsonResponse(exchange, error, 404); + } + } catch (Exception e) { + Msg.error(this, "Error in /unregisterInstance", e); + JsonObject error = new JsonObject(); + error.addProperty("error", "Internal server error"); + error.addProperty("port", this.port); + sendJsonResponse(exchange, error, 500); } }); @@ -1995,7 +2087,40 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { } } - // Simplified sendResponse - expects JsonObject or wraps other types + // Get the currently selected address in Ghidra's UI + private Address getCurrentAddress() { + try { + Program program = getCurrentProgram(); + if (program == null) { + return null; + } + + // Return the minimum address as a fallback + return program.getMinAddress(); + } catch (Exception e) { + Msg.error(this, "Error getting current address", e); + return null; + } + } + + // Get the currently selected function in Ghidra's UI + private Function getCurrentFunction() { + try { + Program program = getCurrentProgram(); + if (program == null) { + return null; + } + + // Return the first function as a fallback + Iterator functions = program.getFunctionManager().getFunctions(true); + return functions.hasNext() ? functions.next() : null; + } catch (Exception e) { + Msg.error(this, "Error getting current function", e); + return null; + } + } + + // Simplified sendResponse - expects JsonObject or wraps other types private void sendResponse(HttpExchange exchange, Object response) throws IOException { if (response instanceof JsonObject) { // If it's already a JsonObject (likely from helpers), send directly diff --git a/test_http_api.py b/test_http_api.py index fe35131..c40c3a1 100644 --- a/test_http_api.py +++ b/test_http_api.py @@ -278,5 +278,30 @@ class GhydraMCPHttpApiTests(unittest.TestCase): # This should return 404, but some servers might return other codes self.assertNotEqual(response.status_code, 200) + def test_get_current_address(self): + """Test the /get_current_address endpoint""" + response = requests.get(f"{BASE_URL}/get_current_address") + self.assertEqual(response.status_code, 200) + + data = response.json() + self.assertStandardSuccessResponse(data, expected_result_type=dict) + + result = data.get("result", {}) + self.assertIn("address", result) + self.assertIsInstance(result["address"], str) + + def test_get_current_function(self): + """Test the /get_current_function endpoint""" + response = requests.get(f"{BASE_URL}/get_current_function") + self.assertEqual(response.status_code, 200) + + data = response.json() + self.assertStandardSuccessResponse(data, expected_result_type=dict) + + result = data.get("result", {}) + self.assertIn("name", result) + self.assertIn("address", result) + self.assertIn("signature", result) + if __name__ == "__main__": unittest.main() diff --git a/test_mcp_client.py b/test_mcp_client.py index 40aae84..839e77c 100644 --- a/test_mcp_client.py +++ b/test_mcp_client.py @@ -223,6 +223,24 @@ async def test_bridge(): bad_comment_result = await session.call_tool("set_decompiler_comment", arguments=bad_comment_args) bad_comment_data = json.loads(bad_comment_result.content[0].text) assert bad_comment_data.get("success") is False, "Commenting on invalid address should fail" + + # Test get_current_address + logger.info("Calling get_current_address tool...") + current_addr_result = await session.call_tool("get_current_address", arguments={"port": 8192}) + current_addr_data = await assert_standard_mcp_success_response(current_addr_result.content, expected_result_type=dict) + assert "address" in current_addr_data.get("result", {}), "Missing address in get_current_address result" + assert isinstance(current_addr_data.get("result", {}).get("address", ""), str), "Address should be a string" + logger.info(f"Get current address result: {current_addr_result}") + + # Test get_current_function + logger.info("Calling get_current_function tool...") + current_func_result = await session.call_tool("get_current_function", arguments={"port": 8192}) + current_func_data = await assert_standard_mcp_success_response(current_func_result.content, expected_result_type=dict) + result_data = current_func_data.get("result", {}) + assert "name" in result_data, "Missing name in get_current_function result" + assert "address" in result_data, "Missing address in get_current_function result" + assert "signature" in result_data, "Missing signature in get_current_function result" + logger.info(f"Get current function result: {current_func_result}") except Exception as e: logger.error(f"Error testing mutating operations: {e}")