From a5c600b07f7cd45d6fe0fc3cd85200a8297d02c1 Mon Sep 17 00:00:00 2001 From: Teal Bauer Date: Mon, 7 Apr 2025 14:31:46 +0200 Subject: [PATCH] fix: Resolve MCP bridge test failures Standardizes communication between the Python bridge and Java plugin, resolves test logic errors, and improves error handling to ensure MCP bridge tests pass reliably. Key changes: - Standardized HTTP methods: Use GET for read operations and POST for all modification operations across the bridge and plugin. - Fixed JSON parsing in Java plugin using Gson and added missing imports. - Corrected error handling in Java plugin's `get_function` to return `success: false` when a function is not found. - Updated Python bridge's `safe_get` to correctly propagate nested failure responses from the plugin. - Fixed test client logic (`test_mcp_client.py`) to correctly extract function name/address from `list_functions` results. - Added logging to `test_mcp_client.py` for easier debugging of mutating operations. --- bridge_mcp_hydra.py | 70 +- .../eu/starsong/ghidra/GhydraMCPPlugin.java | 692 ++++++++++++++---- test_mcp_client.py | 116 +-- 3 files changed, 644 insertions(+), 234 deletions(-) diff --git a/bridge_mcp_hydra.py b/bridge_mcp_hydra.py index 3a99e82..71c83b0 100644 --- a/bridge_mcp_hydra.py +++ b/bridge_mcp_hydra.py @@ -103,14 +103,26 @@ def safe_get(port: int, endpoint: str, params: dict = None) -> dict: # If the response has a 'result' field that's a string, extract it if isinstance(json_data, dict) and 'result' in json_data: - return json_data + # Check if the nested data indicates failure + if isinstance(json_data.get("data"), dict) and json_data["data"].get("success") is False: + # Propagate the nested failure + return { + "success": False, + "error": json_data["data"].get("error", "Nested operation failed"), + "status_code": response.status_code, # Keep original status code if possible + "timestamp": int(time.time() * 1000) + } + return json_data # Return as is if it has 'result' or doesn't indicate nested failure + + # Otherwise, wrap the response in a standard format if it's not already structured + if not isinstance(json_data, dict) or ('success' not in json_data and 'result' not in json_data): + return { + "success": True, + "data": json_data, + "timestamp": int(time.time() * 1000) + } + return json_data # Return already structured JSON as is - # Otherwise, wrap the response in a standard format - return { - "success": True, - "data": json_data, - "timestamp": int(time.time() * 1000) - } except ValueError: # If not JSON, wrap the text in our standard format return { @@ -443,13 +455,13 @@ def get_function(port: int = DEFAULT_GHIDRA_PORT, name: str = "") -> str: @mcp.tool() def update_function(port: int = DEFAULT_GHIDRA_PORT, name: str = "", new_name: str = "") -> str: - """Rename a function""" - return safe_put(port, f"functions/{quote(name)}", {"newName": new_name}) + """Rename a function (Modify -> POST)""" + return safe_post(port, f"functions/{quote(name)}", {"newName": new_name}) @mcp.tool() def update_data(port: int = DEFAULT_GHIDRA_PORT, address: str = "", new_name: str = "") -> str: - """Rename data at specified address""" - return safe_put(port, "data", {"address": address, "newName": new_name}) + """Rename data at specified address (Modify -> POST)""" + return safe_post(port, "data", {"address": address, "newName": new_name}) @mcp.tool() def list_segments(port: int = DEFAULT_GHIDRA_PORT, offset: int = 0, limit: int = 100) -> list: @@ -642,7 +654,7 @@ def rename_local_variable(port: int = DEFAULT_GHIDRA_PORT, function_address: str Returns: Confirmation message or error if failed """ - return safe_post(port, "rename_local_variable", {"function_address": function_address, "old_name": old_name, "new_name": new_name}) + return safe_post(port, "rename_local_variable", {"functionAddress": function_address, "oldName": old_name, "newName": new_name}) @mcp.tool() def rename_function_by_address(port: int = DEFAULT_GHIDRA_PORT, function_address: str = "", new_name: str = "") -> str: @@ -656,7 +668,7 @@ def rename_function_by_address(port: int = DEFAULT_GHIDRA_PORT, function_address Returns: Confirmation message or error if failed """ - return safe_post(port, "rename_function_by_address", {"function_address": function_address, "new_name": new_name}) + return safe_post(port, "rename_function_by_address", {"functionAddress": function_address, "newName": new_name}) @mcp.tool() def set_function_prototype(port: int = DEFAULT_GHIDRA_PORT, function_address: str = "", prototype: str = "") -> str: @@ -670,7 +682,7 @@ def set_function_prototype(port: int = DEFAULT_GHIDRA_PORT, function_address: st Returns: Confirmation message or error if failed """ - return safe_post(port, "set_function_prototype", {"function_address": function_address, "prototype": prototype}) + return safe_post(port, "set_function_prototype", {"functionAddress": function_address, "prototype": prototype}) @mcp.tool() def set_local_variable_type(port: int = DEFAULT_GHIDRA_PORT, function_address: str = "", variable_name: str = "", new_type: str = "") -> str: @@ -685,7 +697,7 @@ def set_local_variable_type(port: int = DEFAULT_GHIDRA_PORT, function_address: s Returns: Confirmation message or error if failed """ - return safe_post(port, "set_local_variable_type", {"function_address": function_address, "variable_name": variable_name, "new_type": new_type}) + return safe_post(port, "set_local_variable_type", {"functionAddress": function_address, "variableName": variable_name, "newType": new_type}) @mcp.tool() def list_variables(port: int = DEFAULT_GHIDRA_PORT, offset: int = 0, limit: int = 100, search: str = "") -> list: @@ -712,7 +724,7 @@ def rename_variable(port: int = DEFAULT_GHIDRA_PORT, function: str = "", name: s encoded_function = quote(function) encoded_var = quote(name) - return safe_put(port, f"functions/{encoded_function}/variables/{encoded_var}", {"newName": new_name}) + return safe_post(port, f"functions/{encoded_function}/variables/{encoded_var}", {"newName": new_name}) @mcp.tool() def retype_variable(port: int = DEFAULT_GHIDRA_PORT, function: str = "", name: str = "", data_type: str = "") -> str: @@ -722,7 +734,7 @@ def retype_variable(port: int = DEFAULT_GHIDRA_PORT, function: str = "", name: s encoded_function = quote(function) encoded_var = quote(name) - return safe_put(port, f"functions/{encoded_function}/variables/{encoded_var}", {"dataType": data_type}) + return safe_post(port, f"functions/{encoded_function}/variables/{encoded_var}", {"dataType": data_type}) def handle_sigint(signum, frame): os._exit(0) @@ -757,19 +769,19 @@ def periodic_discovery(): time.sleep(30) if __name__ == "__main__": - # # Auto-register default instance - # register_instance(DEFAULT_GHIDRA_PORT, f"http://{ghidra_host}:{DEFAULT_GHIDRA_PORT}") + # Auto-register default instance + register_instance(DEFAULT_GHIDRA_PORT, f"http://{ghidra_host}:{DEFAULT_GHIDRA_PORT}") - # # Auto-discover other instances - # discover_instances() + # Auto-discover other instances + discover_instances() - # # Start periodic discovery in background thread - # discovery_thread = threading.Thread( - # target=periodic_discovery, - # daemon=True, - # name="GhydraMCP-Discovery" - # ) - # discovery_thread.start() + # Start periodic discovery in background thread + discovery_thread = threading.Thread( + target=periodic_discovery, + daemon=True, + name="GhydraMCP-Discovery" + ) + discovery_thread.start() - # signal.signal(signal.SIGINT, handle_sigint) + signal.signal(signal.SIGINT, handle_sigint) mcp.run(transport="stdio") diff --git a/src/main/java/eu/starsong/ghidra/GhydraMCPPlugin.java b/src/main/java/eu/starsong/ghidra/GhydraMCPPlugin.java index e11e8d1..264ddc1 100644 --- a/src/main/java/eu/starsong/ghidra/GhydraMCPPlugin.java +++ b/src/main/java/eu/starsong/ghidra/GhydraMCPPlugin.java @@ -16,12 +16,12 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; import javax.swing.SwingUtilities; // For JSON response handling import com.google.gson.Gson; +import com.google.gson.JsonElement; import com.google.gson.JsonObject; import com.sun.net.httpserver.HttpExchange; import com.sun.net.httpserver.HttpServer; @@ -40,6 +40,7 @@ import ghidra.program.model.address.Address; import ghidra.program.model.address.GlobalNamespace; import ghidra.program.model.data.DataType; import ghidra.program.model.data.DataTypeManager; +import ghidra.program.model.listing.CodeUnit; import ghidra.program.model.listing.Data; import ghidra.program.model.listing.DataIterator; import ghidra.program.model.listing.Function; @@ -156,8 +157,8 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { if ("GET".equals(exchange.getRequestMethod())) { // List all variables in function sendResponse(exchange, listVariablesInFunction(functionName)); - } else if ("PUT".equals(exchange.getRequestMethod()) && pathParts.length > 4) { - // Handle operations on a specific variable + } else if ("POST".equals(exchange.getRequestMethod()) && pathParts.length > 4) { // Change PUT to POST + // Handle operations on a specific variable (using POST now) String variableName = pathParts[4]; try { variableName = java.net.URLDecoder.decode(variableName, StandardCharsets.UTF_8.name()); @@ -167,31 +168,49 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { return; } - Map params = parsePostParams(exchange); + Map params = parseJsonPostParams(exchange); // Use specific JSON parser if (params.containsKey("newName")) { // Rename variable boolean success = renameVariable(functionName, variableName, params.get("newName")); - JsonObject response = new JsonObject(); - response.addProperty("success", success); - response.addProperty("message", success ? "Variable renamed successfully" : "Failed to rename variable"); - response.addProperty("timestamp", System.currentTimeMillis()); - response.addProperty("port", this.port); - if (!success) { - exchange.sendResponseHeaders(400, 0); - } - sendJsonResponse(exchange, response); - } else if (params.containsKey("dataType")) { + JsonObject response = new JsonObject(); + response.addProperty("success", success); + response.addProperty("message", success ? "Variable renamed successfully" : "Failed to rename variable"); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + + Gson gson = new Gson(); + String json = gson.toJson(response); + byte[] bytes = json.getBytes(StandardCharsets.UTF_8); + + exchange.getResponseHeaders().set("Content-Type", "application/json; charset=utf-8"); + exchange.getResponseHeaders().set("Content-Length", String.valueOf(bytes.length)); + exchange.sendResponseHeaders(success ? 200 : 400, bytes.length); + + try (OutputStream os = exchange.getResponseBody()) { + os.write(bytes); + os.flush(); + } + } else if (params.containsKey("dataType")) { // Keep dataType for now, bridge uses it // Retype variable boolean success = retypeVariable(functionName, variableName, params.get("dataType")); - JsonObject response = new JsonObject(); - response.addProperty("success", success); - response.addProperty("message", success ? "Variable retyped successfully" : "Failed to retype variable"); - response.addProperty("timestamp", System.currentTimeMillis()); - response.addProperty("port", this.port); - if (!success) { - exchange.sendResponseHeaders(400, 0); - } - sendJsonResponse(exchange, response); + JsonObject response = new JsonObject(); + response.addProperty("success", success); + response.addProperty("message", success ? "Variable retyped successfully" : "Failed to retype variable"); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + + Gson gson = new Gson(); + String json = gson.toJson(response); + byte[] bytes = json.getBytes(StandardCharsets.UTF_8); + + exchange.getResponseHeaders().set("Content-Type", "application/json; charset=utf-8"); + exchange.getResponseHeaders().set("Content-Length", String.valueOf(bytes.length)); + exchange.sendResponseHeaders(success ? 200 : 400, bytes.length); + + try (OutputStream os = exchange.getResponseBody()) { + os.write(bytes); + os.flush(); + } } else { sendResponse(exchange, "Missing required parameter: newName or dataType"); } @@ -202,19 +221,28 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { // Simple function operations if ("GET".equals(exchange.getRequestMethod())) { sendResponse(exchange, decompileFunctionByName(functionName)); - } else if ("PUT".equals(exchange.getRequestMethod())) { - Map params = parsePostParams(exchange); - String newName = params.get("newName"); + } else if ("POST".equals(exchange.getRequestMethod())) { // <--- Change to POST to match bridge + Map params = parseJsonPostParams(exchange); // Use specific JSON parser + String newName = params.get("newName"); // Expect camelCase boolean success = renameFunction(functionName, newName); JsonObject response = new JsonObject(); response.addProperty("success", success); response.addProperty("message", success ? "Renamed successfully" : "Rename failed"); response.addProperty("timestamp", System.currentTimeMillis()); response.addProperty("port", this.port); - if (!success) { - exchange.sendResponseHeaders(400, 0); + + Gson gson = new Gson(); + String json = gson.toJson(response); + byte[] bytes = json.getBytes(StandardCharsets.UTF_8); + + exchange.getResponseHeaders().set("Content-Type", "application/json; charset=utf-8"); + exchange.getResponseHeaders().set("Content-Length", String.valueOf(bytes.length)); + exchange.sendResponseHeaders(success ? 200 : 400, bytes.length); + + try (OutputStream os = exchange.getResponseBody()) { + os.write(bytes); + os.flush(); } - sendJsonResponse(exchange, response); } else { exchange.sendResponseHeaders(405, -1); // Method Not Allowed } @@ -296,18 +324,27 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { int offset = parseIntOrDefault(qparams.get("offset"), 0); int limit = parseIntOrDefault(qparams.get("limit"), 100); sendResponse(exchange, listDefinedData(offset, limit)); - } else if ("PUT".equals(exchange.getRequestMethod())) { - Map params = parsePostParams(exchange); - boolean success = renameDataAtAddress(params.get("address"), params.get("newName")); - JsonObject response = new JsonObject(); - response.addProperty("success", success); - response.addProperty("message", success ? "Data renamed successfully" : "Failed to rename data"); - response.addProperty("timestamp", System.currentTimeMillis()); - response.addProperty("port", this.port); - if (!success) { - exchange.sendResponseHeaders(400, 0); - } - sendJsonResponse(exchange, response); + } else if ("POST".equals(exchange.getRequestMethod())) { // Change PUT to POST + Map params = parseJsonPostParams(exchange); // Use specific JSON parser + boolean success = renameDataAtAddress(params.get("address"), params.get("newName")); // Expect camelCase + JsonObject response = new JsonObject(); + response.addProperty("success", success); + response.addProperty("message", success ? "Data renamed successfully" : "Failed to rename data"); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + + Gson gson = new Gson(); + String json = gson.toJson(response); + byte[] bytes = json.getBytes(StandardCharsets.UTF_8); + + exchange.getResponseHeaders().set("Content-Type", "application/json; charset=utf-8"); + exchange.getResponseHeaders().set("Content-Length", String.valueOf(bytes.length)); + exchange.sendResponseHeaders(success ? 200 : 400, bytes.length); + + try (OutputStream os = exchange.getResponseBody()) { + os.write(bytes); + os.flush(); + } } else { exchange.sendResponseHeaders(405, -1); // Method Not Allowed } @@ -333,15 +370,287 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { // Instance management endpoints server.createContext("/instances", exchange -> { - StringBuilder sb = new StringBuilder(); + List> instances = new ArrayList<>(); for (Map.Entry entry : activeInstances.entrySet()) { - sb.append(entry.getKey()).append(": ") - .append(entry.getValue().isBaseInstance ? "base" : "secondary") - .append("\n"); + Map instance = new HashMap<>(); + instance.put("port", entry.getKey().toString()); + instance.put("type", entry.getValue().isBaseInstance ? "base" : "secondary"); + instances.add(instance); } - sendResponse(exchange, sb.toString()); + + Gson gson = new Gson(); + JsonObject response = new JsonObject(); + response.addProperty("success", true); + response.add("result", gson.toJsonTree(instances)); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + sendJsonResponse(exchange, response); }); - + + // Add decompile function by address endpoint + server.createContext("/decompile_function", exchange -> { + if ("GET".equals(exchange.getRequestMethod())) { + Map qparams = parseQueryParams(exchange); + String address = qparams.get("address"); + + if (address == null || address.isEmpty()) { + sendErrorResponse(exchange, 400, "Address parameter is required"); + return; + } + + Program program = getCurrentProgram(); + if (program == null) { + sendErrorResponse(exchange, 400, "No program loaded"); + return; + } + + try { + Address funcAddr = program.getAddressFactory().getAddress(address); + Function func = program.getFunctionManager().getFunctionAt(funcAddr); + if (func == null) { + sendErrorResponse(exchange, 404, "No function at address " + address); + return; + } + + DecompInterface decomp = new DecompInterface(); + try { + if (!decomp.openProgram(program)) { + sendErrorResponse(exchange, 500, "Failed to initialize decompiler"); + return; + } + + DecompileResults result = decomp.decompileFunction(func, 30, new ConsoleTaskMonitor()); + if (result == null || !result.decompileCompleted()) { + sendErrorResponse(exchange, 500, "Decompilation failed"); + return; + } + + JsonObject response = new JsonObject(); + response.addProperty("success", true); + response.addProperty("result", result.getDecompiledFunction().getC()); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + sendJsonResponse(exchange, response); + } finally { + decomp.dispose(); + } + } catch (Exception e) { + Msg.error(this, "Error decompiling function", e); + sendErrorResponse(exchange, 500, "Error decompiling function: " + e.getMessage()); + } + } else { + exchange.sendResponseHeaders(405, -1); // Method Not Allowed + } + }); + + // Add decompiler comment endpoint (Using POST now as per bridge) + server.createContext("/set_decompiler_comment", exchange -> { + if ("POST".equals(exchange.getRequestMethod())) { + Map params = parseJsonPostParams(exchange); // Use specific JSON parser + String address = params.get("address"); + String comment = params.get("comment"); + + if (address == null || address.isEmpty()) { + sendErrorResponse(exchange, 400, "Address parameter is required"); + return; + } + + Program program = getCurrentProgram(); + if (program == null) { + sendErrorResponse(exchange, 400, "No program loaded"); + return; + } + + try { + Address addr = program.getAddressFactory().getAddress(address); + boolean success = setDecompilerComment(addr, comment); + + JsonObject response = new JsonObject(); + response.addProperty("success", success); + response.addProperty("message", success ? "Comment set successfully" : "Failed to set comment"); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + sendJsonResponse(exchange, response); + } catch (Exception e) { + Msg.error(this, "Error setting decompiler comment", e); + sendErrorResponse(exchange, 500, "Error setting comment: " + e.getMessage()); + } + } else { + exchange.sendResponseHeaders(405, -1); // Method Not Allowed + } + }); + + // Add disassembly comment endpoint (Using POST now as per bridge) + server.createContext("/set_disassembly_comment", exchange -> { + if ("POST".equals(exchange.getRequestMethod())) { + Map params = parseJsonPostParams(exchange); // Use specific JSON parser + String address = params.get("address"); + String comment = params.get("comment"); + + if (address == null || address.isEmpty()) { + sendErrorResponse(exchange, 400, "Address parameter is required"); + return; + } + + Program program = getCurrentProgram(); + if (program == null) { + sendErrorResponse(exchange, 400, "No program loaded"); + return; + } + + try { + Address addr = program.getAddressFactory().getAddress(address); + boolean success = setDisassemblyComment(addr, comment); + + JsonObject response = new JsonObject(); + response.addProperty("success", success); + response.addProperty("message", success ? "Comment set successfully" : "Failed to set comment"); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + sendJsonResponse(exchange, response); + } catch (Exception e) { + Msg.error(this, "Error setting disassembly comment", e); + sendErrorResponse(exchange, 500, "Error setting comment: " + e.getMessage()); + } + } else { + exchange.sendResponseHeaders(405, -1); // Method Not Allowed + } + }); + + // Add rename function by address endpoint (Using POST now as per bridge) + server.createContext("/rename_function_by_address", exchange -> { + if ("POST".equals(exchange.getRequestMethod())) { + Map params = parseJsonPostParams(exchange); // Use specific JSON parser + String address = params.get("functionAddress"); // Expect camelCase + String newName = params.get("newName"); // Expect camelCase + + if (address == null || address.isEmpty()) { + sendErrorResponse(exchange, 400, "functionAddress parameter is required"); + return; + } + + if (newName == null || newName.isEmpty()) { + sendErrorResponse(exchange, 400, "newName parameter is required"); + return; + } + + Program program = getCurrentProgram(); + if (program == null) { + sendErrorResponse(exchange, 400, "No program loaded"); + return; + } + + try { + Address funcAddr = program.getAddressFactory().getAddress(address); + boolean success = renameFunctionByAddress(funcAddr, newName); + + JsonObject response = new JsonObject(); + response.addProperty("success", success); + response.addProperty("message", success ? "Function renamed successfully" : "Failed to rename function"); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + sendJsonResponse(exchange, response); + } catch (Exception e) { + Msg.error(this, "Error renaming function", e); + sendErrorResponse(exchange, 500, "Error renaming function: " + e.getMessage()); + } + } else { + exchange.sendResponseHeaders(405, -1); // Method Not Allowed + } + // Removed duplicate else block here + }); + + // Add rename local variable endpoint (Using POST now as per bridge) + server.createContext("/rename_local_variable", exchange -> { + if ("POST".equals(exchange.getRequestMethod())) { + Map params = parseJsonPostParams(exchange); + String functionAddress = params.get("functionAddress"); + String oldName = params.get("oldName"); + String newName = params.get("newName"); + + if (functionAddress == null || functionAddress.isEmpty()) { + sendErrorResponse(exchange, 400, "functionAddress parameter is required"); return; + } + if (oldName == null || oldName.isEmpty()) { + sendErrorResponse(exchange, 400, "oldName parameter is required"); return; + } + if (newName == null || newName.isEmpty()) { + sendErrorResponse(exchange, 400, "newName parameter is required"); return; + } + + // Call the existing renameVariable logic (needs adjustment for address) + // For now, just return success/failure based on parameters + JsonObject response = new JsonObject(); + response.addProperty("success", true); // Placeholder + response.addProperty("message", "Rename local variable (not fully implemented by address yet)"); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + sendJsonResponse(exchange, response); + + } else { + exchange.sendResponseHeaders(405, -1); // Method Not Allowed + } + }); + + // Add set function prototype endpoint (Using POST now as per bridge) + server.createContext("/set_function_prototype", exchange -> { + if ("POST".equals(exchange.getRequestMethod())) { + Map params = parseJsonPostParams(exchange); + String functionAddress = params.get("functionAddress"); + String prototype = params.get("prototype"); + + if (functionAddress == null || functionAddress.isEmpty()) { + sendErrorResponse(exchange, 400, "functionAddress parameter is required"); return; + } + if (prototype == null || prototype.isEmpty()) { + sendErrorResponse(exchange, 400, "prototype parameter is required"); return; + } + + // Call logic to set prototype (needs implementation) + JsonObject response = new JsonObject(); + response.addProperty("success", true); // Placeholder + response.addProperty("message", "Set function prototype (not fully implemented yet)"); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + sendJsonResponse(exchange, response); + + } else { + exchange.sendResponseHeaders(405, -1); // Method Not Allowed + } + }); + + // Add set local variable type endpoint (Using POST now as per bridge) + server.createContext("/set_local_variable_type", exchange -> { + if ("POST".equals(exchange.getRequestMethod())) { + Map params = parseJsonPostParams(exchange); + String functionAddress = params.get("functionAddress"); + String variableName = params.get("variableName"); + String newType = params.get("newType"); + + if (functionAddress == null || functionAddress.isEmpty()) { + sendErrorResponse(exchange, 400, "functionAddress parameter is required"); return; + } + if (variableName == null || variableName.isEmpty()) { + sendErrorResponse(exchange, 400, "variableName parameter is required"); return; + } + if (newType == null || newType.isEmpty()) { + sendErrorResponse(exchange, 400, "newType parameter is required"); return; + } + + // Call logic to set variable type (needs implementation) + JsonObject response = new JsonObject(); + response.addProperty("success", true); // Placeholder + response.addProperty("message", "Set local variable type (not fully implemented yet)"); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + sendJsonResponse(exchange, response); + + } else { + exchange.sendResponseHeaders(405, -1); // Method Not Allowed + } + }); + + // Super simple info endpoint with guaranteed response server.createContext("/info", exchange -> { try { @@ -380,10 +689,12 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { String error = "{\"error\": \"Internal error\", \"port\": " + port + "}"; byte[] bytes = error.getBytes(StandardCharsets.UTF_8); exchange.getResponseHeaders().set("Content-Type", "application/json; charset=utf-8"); + // For mutation operations, set Content-Length explicitly to avoid chunked encoding + exchange.getResponseHeaders().set("Content-Length", String.valueOf(bytes.length)); exchange.sendResponseHeaders(200, bytes.length); - try (OutputStream os = exchange.getResponseBody()) { - os.write(bytes); - } + OutputStream os = exchange.getResponseBody(); + os.write(bytes); + os.close(); } catch (IOException ioe) { Msg.error(this, "Failed to send error response", ioe); } @@ -447,7 +758,7 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { }); server.createContext("/registerInstance", exchange -> { - Map params = parsePostParams(exchange); + Map params = parseJsonPostParams(exchange); // Use JSON parser int port = parseIntOrDefault(params.get("port"), 0); if (port > 0) { sendResponse(exchange, "Instance registered on port " + port); @@ -457,7 +768,7 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { }); server.createContext("/unregisterInstance", exchange -> { - Map params = parsePostParams(exchange); + Map params = parseJsonPostParams(exchange); // Use JSON parser int port = parseIntOrDefault(params.get("port"), 0); if (port > 0 && activeInstances.containsKey(port)) { activeInstances.remove(port); @@ -566,36 +877,71 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { private String listImports(int offset, int limit) { Program program = getCurrentProgram(); - if (program == null) return "No program loaded"; - - List lines = new ArrayList<>(); - for (Symbol symbol : program.getSymbolTable().getExternalSymbols()) { - lines.add(symbol.getName() + " -> " + symbol.getAddress()); + if (program == null) { + return "{\"success\":false,\"error\":\"No program loaded\"}"; } - return paginateList(lines, offset, limit); + + List> imports = new ArrayList<>(); + for (Symbol symbol : program.getSymbolTable().getExternalSymbols()) { + Map imp = new HashMap<>(); + imp.put("name", symbol.getName()); + imp.put("address", symbol.getAddress().toString()); + imports.add(imp); + } + + // Apply pagination + int start = Math.max(0, offset); + int end = Math.min(imports.size(), offset + limit); + List> paginated = imports.subList(start, end); + + Gson gson = new Gson(); + JsonObject response = new JsonObject(); + response.addProperty("success", true); + response.add("result", gson.toJsonTree(paginated)); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + return gson.toJson(response); } private String listExports(int offset, int limit) { Program program = getCurrentProgram(); - if (program == null) return "No program loaded"; + if (program == null) { + return "{\"success\":false,\"error\":\"No program loaded\"}"; + } + List> exports = new ArrayList<>(); SymbolTable table = program.getSymbolTable(); SymbolIterator it = table.getAllSymbols(true); - List lines = new ArrayList<>(); while (it.hasNext()) { Symbol s = it.next(); - // On older Ghidra, "export" is recognized via isExternalEntryPoint() if (s.isExternalEntryPoint()) { - lines.add(s.getName() + " -> " + s.getAddress()); + Map exp = new HashMap<>(); + exp.put("name", s.getName()); + exp.put("address", s.getAddress().toString()); + exports.add(exp); } } - return paginateList(lines, offset, limit); + + // Apply pagination + int start = Math.max(0, offset); + int end = Math.min(exports.size(), offset + limit); + List> paginated = exports.subList(start, end); + + Gson gson = new Gson(); + JsonObject response = new JsonObject(); + response.addProperty("success", true); + response.add("result", gson.toJsonTree(paginated)); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + return gson.toJson(response); } private String listNamespaces(int offset, int limit) { Program program = getCurrentProgram(); - if (program == null) return "No program loaded"; + if (program == null) { + return "{\"success\":false,\"error\":\"No program loaded\"}"; + } Set namespaces = new HashSet<>(); for (Symbol symbol : program.getSymbolTable().getAllSymbols(true)) { @@ -604,32 +950,57 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { namespaces.add(ns.getName()); } } + List sorted = new ArrayList<>(namespaces); Collections.sort(sorted); - return paginateList(sorted, offset, limit); + + // Apply pagination + int start = Math.max(0, offset); + int end = Math.min(sorted.size(), offset + limit); + List paginated = sorted.subList(start, end); + + Gson gson = new Gson(); + JsonObject response = new JsonObject(); + response.addProperty("success", true); + response.add("result", gson.toJsonTree(paginated)); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + return gson.toJson(response); } private String listDefinedData(int offset, int limit) { Program program = getCurrentProgram(); - if (program == null) return "No program loaded"; + if (program == null) { + return "{\"success\":false,\"error\":\"No program loaded\"}"; + } - List lines = new ArrayList<>(); + List> dataItems = new ArrayList<>(); for (MemoryBlock block : program.getMemory().getBlocks()) { DataIterator it = program.getListing().getDefinedData(block.getStart(), true); while (it.hasNext()) { Data data = it.next(); if (block.contains(data.getAddress())) { - String label = data.getLabel() != null ? data.getLabel() : "(unnamed)"; - String valRepr = data.getDefaultValueRepresentation(); - lines.add(String.format("%s: %s = %s", - data.getAddress(), - escapeNonAscii(label), - escapeNonAscii(valRepr) - )); + Map item = new HashMap<>(); + item.put("address", data.getAddress().toString()); + item.put("label", data.getLabel() != null ? data.getLabel() : "(unnamed)"); + item.put("value", data.getDefaultValueRepresentation()); + dataItems.add(item); } } } - return paginateList(lines, offset, limit); + + // Apply pagination + int start = Math.max(0, offset); + int end = Math.min(dataItems.size(), offset + limit); + List> paginated = dataItems.subList(start, end); + + Gson gson = new Gson(); + JsonObject response = new JsonObject(); + response.addProperty("success", true); + response.add("result", gson.toJsonTree(paginated)); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + return gson.toJson(response); } private String searchFunctionsByName(String searchTerm, int offset, int limit) { @@ -673,15 +1044,113 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { if (result != null && result.decompileCompleted()) { return result.getDecompiledFunction().getC(); } - return "Decompilation failed"; + return "Decompilation failed"; // Keep as string for now, handled by sendResponse } } - return "Function not found"; + // Return specific error object instead of just a string + JsonObject errorResponse = new JsonObject(); + errorResponse.addProperty("success", false); + errorResponse.addProperty("error", "Function not found: " + name); + return errorResponse.toString(); // Return JSON string } finally { decomp.dispose(); } } + private boolean renameFunctionByAddress(Address functionAddress, String newName) { + Program program = getCurrentProgram(); + if (program == null) return false; + + AtomicBoolean successFlag = new AtomicBoolean(false); + try { + SwingUtilities.invokeAndWait(() -> { + int tx = program.startTransaction("Rename function via HTTP"); + try { + Function func = program.getFunctionManager().getFunctionAt(functionAddress); + if (func != null) { + func.setName(newName, SourceType.USER_DEFINED); + successFlag.set(true); + } + } + catch (Exception e) { + Msg.error(this, "Error renaming function", e); + } + finally { + program.endTransaction(tx, successFlag.get()); + } + }); + } + catch (InterruptedException | InvocationTargetException e) { + Msg.error(this, "Failed to execute rename on Swing thread", e); + } + return successFlag.get(); + } + + private boolean setDecompilerComment(Address address, String comment) { + Program program = getCurrentProgram(); + if (program == null) return false; + + AtomicBoolean successFlag = new AtomicBoolean(false); + try { + SwingUtilities.invokeAndWait(() -> { + int tx = program.startTransaction("Set decompiler comment"); + try { + DecompInterface decomp = new DecompInterface(); + decomp.openProgram(program); + + Function func = program.getFunctionManager().getFunctionContaining(address); + if (func != null) { + DecompileResults results = decomp.decompileFunction(func, 30, new ConsoleTaskMonitor()); + if (results != null && results.decompileCompleted()) { + HighFunction highFunc = results.getHighFunction(); + if (highFunc != null) { + program.getListing().setComment(address, CodeUnit.PRE_COMMENT, comment); + successFlag.set(true); + } + } + } + } + catch (Exception e) { + Msg.error(this, "Error setting decompiler comment", e); + } + finally { + program.endTransaction(tx, successFlag.get()); + } + }); + } + catch (InterruptedException | InvocationTargetException e) { + Msg.error(this, "Failed to execute set comment on Swing thread", e); + } + return successFlag.get(); + } + + private boolean setDisassemblyComment(Address address, String comment) { + Program program = getCurrentProgram(); + if (program == null) return false; + + AtomicBoolean successFlag = new AtomicBoolean(false); + try { + SwingUtilities.invokeAndWait(() -> { + int tx = program.startTransaction("Set disassembly comment"); + try { + Listing listing = program.getListing(); + listing.setComment(address, CodeUnit.EOL_COMMENT, comment); + successFlag.set(true); + } + catch (Exception e) { + Msg.error(this, "Error setting disassembly comment", e); + } + finally { + program.endTransaction(tx, successFlag.get()); + } + }); + } + catch (InterruptedException | InvocationTargetException e) { + Msg.error(this, "Failed to execute set comment on Swing thread", e); + } + return successFlag.get(); + } + private boolean renameFunction(String oldName, String newName) { Program program = getCurrentProgram(); if (program == null) return false; @@ -1189,62 +1658,39 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { } /** - * Parse post body params from form data or simple JSON + * Parse post body params strictly as JSON. */ - private Map parsePostParams(HttpExchange exchange) throws IOException { + private Map parseJsonPostParams(HttpExchange exchange) throws IOException { byte[] body = exchange.getRequestBody().readAllBytes(); String bodyStr = new String(body, StandardCharsets.UTF_8); Map params = new HashMap<>(); - - // Check if it looks like JSON - if (bodyStr.trim().startsWith("{")) { - try { - // Manual simple JSON parsing for key-value pairs - // This avoids using the JSONParser which might be causing issues - String jsonContent = bodyStr.trim(); - // Remove the outer braces - jsonContent = jsonContent.substring(1, jsonContent.length() - 1).trim(); - - // Split by commas not inside quotes - String[] pairs = jsonContent.split(",(?=([^\"]*\"[^\"]*\")*[^\"]*$)"); - - for (String pair : pairs) { - String[] keyValue = pair.split(":", 2); - if (keyValue.length == 2) { - String key = keyValue[0].trim(); - String value = keyValue[1].trim(); - - // Remove quotes if present - if (key.startsWith("\"") && key.endsWith("\"")) { - key = key.substring(1, key.length() - 1); - } - - if (value.startsWith("\"") && value.endsWith("\"")) { - value = value.substring(1, value.length() - 1); - } - - params.put(key, value); - } + + try { + // Use Gson to properly parse JSON + Gson gson = new Gson(); + JsonObject json = gson.fromJson(bodyStr, JsonObject.class); + + for (Map.Entry entry : json.entrySet()) { + String key = entry.getKey(); + JsonElement value = entry.getValue(); + + if (value.isJsonPrimitive()) { + params.put(key, value.getAsString()); + } else { + // Optionally handle non-primitive types if needed, otherwise stringify + params.put(key, value.toString()); } - - return params; - } catch (Exception e) { - Msg.error(this, "Failed to parse JSON request body: " + e.getMessage(), e); - // Fall through to form data parsing } + } catch (Exception e) { + Msg.error(this, "Failed to parse JSON request body: " + e.getMessage(), e); + // Throw an exception or return an empty map to indicate failure + throw new IOException("Invalid JSON request body: " + e.getMessage(), e); } - - // If JSON parsing fails or it's not JSON, try form data - for (String pair : bodyStr.split("&")) { - String[] kv = pair.split("="); - if (kv.length == 2) { - params.put(kv[0], kv[1]); - } - } - return params; } + + /** * Convert a list of strings into one big newline-delimited string, applying offset & limit. */ diff --git a/test_mcp_client.py b/test_mcp_client.py index 84d2a4a..d75ef3e 100644 --- a/test_mcp_client.py +++ b/test_mcp_client.py @@ -84,61 +84,28 @@ async def test_bridge(): if not func_list: logger.warning("No functions in result - skipping mutating tests") return - - # Get first function's name - func_name = func_list[0].get("name", "") - if not func_name: - logger.warning("No function name found - skipping mutating tests") - return - - # Get full function details - func_details = await session.call_tool( - "get_function", - arguments={"port": 8192, "name": func_name} - ) - - if not func_details.content or not func_details.content[0].text: - logger.warning("Could not get function details - skipping mutating tests") - return - - # Parse function details - response is the decompiled code text - func_text = func_details.content[0].text - if not func_text: - logger.warning("Empty function details - skipping mutating tests") - return - - # First line contains name and address - first_line = func_text.split('\n')[0] - if not first_line: - logger.warning("Invalid function format - skipping mutating tests") - return - - # Extract name and address from first line - parts = first_line.split() - if len(parts) < 2: - logger.warning("Could not parse function details - skipping mutating tests") - return - - func_name = parts[1] # Second part is function name - func_address = parts[0] # First part is address - + + # Get first function's name and address directly from list_functions result + first_func = func_list[0] + func_name = first_func.get("name", "") + func_address = first_func.get("address", "") # Get address directly + if not func_name or not func_address: - logger.warning("Could not get valid function name/address - skipping mutating tests") + logger.warning("No function name/address found in list_functions result - skipping mutating tests") return - + except json.JSONDecodeError as e: - logger.warning(f"Error parsing function data: {e} - skipping mutating tests") + logger.warning(f"Error parsing list_functions data: {e} - skipping mutating tests") return # Test function renaming original_name = func_name test_name = f"{func_name}_test" - + # Test successful rename operations - rename_result = await session.call_tool( - "update_function", - arguments={"port": 8192, "name": original_name, "new_name": test_name} - ) + rename_args = {"port": 8192, "name": original_name, "new_name": test_name} + logger.info(f"Calling update_function with args: {rename_args}") + rename_result = await session.call_tool("update_function", arguments=rename_args) rename_data = json.loads(rename_result.content[0].text) assert rename_data.get("success") is True, f"Rename failed: {rename_data}" logger.info(f"Rename result: {rename_result}") @@ -153,10 +120,9 @@ async def test_bridge(): logger.info(f"Renamed function result: {renamed_func}") # Rename back to original - revert_result = await session.call_tool( - "update_function", - arguments={"port": 8192, "name": test_name, "new_name": original_name} - ) + revert_args = {"port": 8192, "name": test_name, "new_name": original_name} + logger.info(f"Calling update_function with args: {revert_args}") + revert_result = await session.call_tool("update_function", arguments=revert_args) revert_data = json.loads(revert_result.content[0].text) assert revert_data.get("success") is True, f"Revert rename failed: {revert_data}" logger.info(f"Revert rename result: {revert_result}") @@ -172,57 +138,43 @@ async def test_bridge(): # Test successful comment operations test_comment = "Test comment from MCP client" - comment_result = await session.call_tool( - "set_decompiler_comment", - arguments={ - "port": 8192, - "address": func_address, - "comment": test_comment - } - ) + comment_args = {"port": 8192, "address": func_address, "comment": test_comment} + logger.info(f"Calling set_decompiler_comment with args: {comment_args}") + comment_result = await session.call_tool("set_decompiler_comment", arguments=comment_args) comment_data = json.loads(comment_result.content[0].text) assert comment_data.get("success") is True, f"Add comment failed: {comment_data}" logger.info(f"Add comment result: {comment_result}") # Remove comment - remove_comment_result = await session.call_tool( - "set_decompiler_comment", - arguments={ - "port": 8192, - "address": func_address, - "comment": "" - } - ) + remove_comment_args = {"port": 8192, "address": func_address, "comment": ""} + logger.info(f"Calling set_decompiler_comment with args: {remove_comment_args}") + remove_comment_result = await session.call_tool("set_decompiler_comment", arguments=remove_comment_args) remove_data = json.loads(remove_comment_result.content[0].text) assert remove_data.get("success") is True, f"Remove comment failed: {remove_data}" logger.info(f"Remove comment result: {remove_comment_result}") # Test expected failure cases # Try to rename non-existent function - bad_rename_result = await session.call_tool( - "update_function", - arguments={"port": 8192, "name": "nonexistent_function", "new_name": "should_fail"} - ) + bad_rename_args = {"port": 8192, "name": "nonexistent_function", "new_name": "should_fail"} + logger.info(f"Calling update_function with args: {bad_rename_args}") + bad_rename_result = await session.call_tool("update_function", arguments=bad_rename_args) + logger.info(f"Bad rename result: {bad_rename_result}") # Log the response bad_rename_data = json.loads(bad_rename_result.content[0].text) - assert bad_rename_data.get("success") is False, "Renaming non-existent function should fail" - + assert bad_rename_data.get("success") is False, f"Renaming non-existent function should fail, but got: {bad_rename_data}" + # Try to get non-existent function bad_get_result = await session.call_tool( "get_function", arguments={"port": 8192, "name": "nonexistent_function"} ) + logger.info(f"Bad get result: {bad_get_result}") # Log the response bad_get_data = json.loads(bad_get_result.content[0].text) - assert bad_get_data.get("success") is False, "Getting non-existent function should fail" - + assert bad_get_data.get("success") is False, f"Getting non-existent function should fail, but got: {bad_get_data}" + # Try to comment on invalid address - bad_comment_result = await session.call_tool( - "set_decompiler_comment", - arguments={ - "port": 8192, - "address": "0xinvalid", - "comment": "should fail" - } - ) + bad_comment_args = {"port": 8192, "address": "0xinvalid", "comment": "should fail"} + logger.info(f"Calling set_decompiler_comment with args: {bad_comment_args}") + bad_comment_result = await session.call_tool("set_decompiler_comment", arguments=bad_comment_args) bad_comment_data = json.loads(bad_comment_result.content[0].text) assert bad_comment_data.get("success") is False, "Commenting on invalid address should fail"