diff --git a/bridge_mcp_hydra.py b/bridge_mcp_hydra.py index 3a99e82..71c83b0 100644 --- a/bridge_mcp_hydra.py +++ b/bridge_mcp_hydra.py @@ -103,14 +103,26 @@ def safe_get(port: int, endpoint: str, params: dict = None) -> dict: # If the response has a 'result' field that's a string, extract it if isinstance(json_data, dict) and 'result' in json_data: - return json_data + # Check if the nested data indicates failure + if isinstance(json_data.get("data"), dict) and json_data["data"].get("success") is False: + # Propagate the nested failure + return { + "success": False, + "error": json_data["data"].get("error", "Nested operation failed"), + "status_code": response.status_code, # Keep original status code if possible + "timestamp": int(time.time() * 1000) + } + return json_data # Return as is if it has 'result' or doesn't indicate nested failure + + # Otherwise, wrap the response in a standard format if it's not already structured + if not isinstance(json_data, dict) or ('success' not in json_data and 'result' not in json_data): + return { + "success": True, + "data": json_data, + "timestamp": int(time.time() * 1000) + } + return json_data # Return already structured JSON as is - # Otherwise, wrap the response in a standard format - return { - "success": True, - "data": json_data, - "timestamp": int(time.time() * 1000) - } except ValueError: # If not JSON, wrap the text in our standard format return { @@ -443,13 +455,13 @@ def get_function(port: int = DEFAULT_GHIDRA_PORT, name: str = "") -> str: @mcp.tool() def update_function(port: int = DEFAULT_GHIDRA_PORT, name: str = "", new_name: str = "") -> str: - """Rename a function""" - return safe_put(port, f"functions/{quote(name)}", {"newName": new_name}) + """Rename a function (Modify -> POST)""" + return safe_post(port, f"functions/{quote(name)}", {"newName": new_name}) @mcp.tool() def update_data(port: int = DEFAULT_GHIDRA_PORT, address: str = "", new_name: str = "") -> str: - """Rename data at specified address""" - return safe_put(port, "data", {"address": address, "newName": new_name}) + """Rename data at specified address (Modify -> POST)""" + return safe_post(port, "data", {"address": address, "newName": new_name}) @mcp.tool() def list_segments(port: int = DEFAULT_GHIDRA_PORT, offset: int = 0, limit: int = 100) -> list: @@ -642,7 +654,7 @@ def rename_local_variable(port: int = DEFAULT_GHIDRA_PORT, function_address: str Returns: Confirmation message or error if failed """ - return safe_post(port, "rename_local_variable", {"function_address": function_address, "old_name": old_name, "new_name": new_name}) + return safe_post(port, "rename_local_variable", {"functionAddress": function_address, "oldName": old_name, "newName": new_name}) @mcp.tool() def rename_function_by_address(port: int = DEFAULT_GHIDRA_PORT, function_address: str = "", new_name: str = "") -> str: @@ -656,7 +668,7 @@ def rename_function_by_address(port: int = DEFAULT_GHIDRA_PORT, function_address Returns: Confirmation message or error if failed """ - return safe_post(port, "rename_function_by_address", {"function_address": function_address, "new_name": new_name}) + return safe_post(port, "rename_function_by_address", {"functionAddress": function_address, "newName": new_name}) @mcp.tool() def set_function_prototype(port: int = DEFAULT_GHIDRA_PORT, function_address: str = "", prototype: str = "") -> str: @@ -670,7 +682,7 @@ def set_function_prototype(port: int = DEFAULT_GHIDRA_PORT, function_address: st Returns: Confirmation message or error if failed """ - return safe_post(port, "set_function_prototype", {"function_address": function_address, "prototype": prototype}) + return safe_post(port, "set_function_prototype", {"functionAddress": function_address, "prototype": prototype}) @mcp.tool() def set_local_variable_type(port: int = DEFAULT_GHIDRA_PORT, function_address: str = "", variable_name: str = "", new_type: str = "") -> str: @@ -685,7 +697,7 @@ def set_local_variable_type(port: int = DEFAULT_GHIDRA_PORT, function_address: s Returns: Confirmation message or error if failed """ - return safe_post(port, "set_local_variable_type", {"function_address": function_address, "variable_name": variable_name, "new_type": new_type}) + return safe_post(port, "set_local_variable_type", {"functionAddress": function_address, "variableName": variable_name, "newType": new_type}) @mcp.tool() def list_variables(port: int = DEFAULT_GHIDRA_PORT, offset: int = 0, limit: int = 100, search: str = "") -> list: @@ -712,7 +724,7 @@ def rename_variable(port: int = DEFAULT_GHIDRA_PORT, function: str = "", name: s encoded_function = quote(function) encoded_var = quote(name) - return safe_put(port, f"functions/{encoded_function}/variables/{encoded_var}", {"newName": new_name}) + return safe_post(port, f"functions/{encoded_function}/variables/{encoded_var}", {"newName": new_name}) @mcp.tool() def retype_variable(port: int = DEFAULT_GHIDRA_PORT, function: str = "", name: str = "", data_type: str = "") -> str: @@ -722,7 +734,7 @@ def retype_variable(port: int = DEFAULT_GHIDRA_PORT, function: str = "", name: s encoded_function = quote(function) encoded_var = quote(name) - return safe_put(port, f"functions/{encoded_function}/variables/{encoded_var}", {"dataType": data_type}) + return safe_post(port, f"functions/{encoded_function}/variables/{encoded_var}", {"dataType": data_type}) def handle_sigint(signum, frame): os._exit(0) @@ -757,19 +769,19 @@ def periodic_discovery(): time.sleep(30) if __name__ == "__main__": - # # Auto-register default instance - # register_instance(DEFAULT_GHIDRA_PORT, f"http://{ghidra_host}:{DEFAULT_GHIDRA_PORT}") + # Auto-register default instance + register_instance(DEFAULT_GHIDRA_PORT, f"http://{ghidra_host}:{DEFAULT_GHIDRA_PORT}") - # # Auto-discover other instances - # discover_instances() + # Auto-discover other instances + discover_instances() - # # Start periodic discovery in background thread - # discovery_thread = threading.Thread( - # target=periodic_discovery, - # daemon=True, - # name="GhydraMCP-Discovery" - # ) - # discovery_thread.start() + # Start periodic discovery in background thread + discovery_thread = threading.Thread( + target=periodic_discovery, + daemon=True, + name="GhydraMCP-Discovery" + ) + discovery_thread.start() - # signal.signal(signal.SIGINT, handle_sigint) + signal.signal(signal.SIGINT, handle_sigint) mcp.run(transport="stdio") diff --git a/src/main/java/eu/starsong/ghidra/GhydraMCPPlugin.java b/src/main/java/eu/starsong/ghidra/GhydraMCPPlugin.java index e11e8d1..264ddc1 100644 --- a/src/main/java/eu/starsong/ghidra/GhydraMCPPlugin.java +++ b/src/main/java/eu/starsong/ghidra/GhydraMCPPlugin.java @@ -16,12 +16,12 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; import javax.swing.SwingUtilities; // For JSON response handling import com.google.gson.Gson; +import com.google.gson.JsonElement; import com.google.gson.JsonObject; import com.sun.net.httpserver.HttpExchange; import com.sun.net.httpserver.HttpServer; @@ -40,6 +40,7 @@ import ghidra.program.model.address.Address; import ghidra.program.model.address.GlobalNamespace; import ghidra.program.model.data.DataType; import ghidra.program.model.data.DataTypeManager; +import ghidra.program.model.listing.CodeUnit; import ghidra.program.model.listing.Data; import ghidra.program.model.listing.DataIterator; import ghidra.program.model.listing.Function; @@ -156,8 +157,8 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { if ("GET".equals(exchange.getRequestMethod())) { // List all variables in function sendResponse(exchange, listVariablesInFunction(functionName)); - } else if ("PUT".equals(exchange.getRequestMethod()) && pathParts.length > 4) { - // Handle operations on a specific variable + } else if ("POST".equals(exchange.getRequestMethod()) && pathParts.length > 4) { // Change PUT to POST + // Handle operations on a specific variable (using POST now) String variableName = pathParts[4]; try { variableName = java.net.URLDecoder.decode(variableName, StandardCharsets.UTF_8.name()); @@ -167,31 +168,49 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { return; } - Map params = parsePostParams(exchange); + Map params = parseJsonPostParams(exchange); // Use specific JSON parser if (params.containsKey("newName")) { // Rename variable boolean success = renameVariable(functionName, variableName, params.get("newName")); - JsonObject response = new JsonObject(); - response.addProperty("success", success); - response.addProperty("message", success ? "Variable renamed successfully" : "Failed to rename variable"); - response.addProperty("timestamp", System.currentTimeMillis()); - response.addProperty("port", this.port); - if (!success) { - exchange.sendResponseHeaders(400, 0); - } - sendJsonResponse(exchange, response); - } else if (params.containsKey("dataType")) { + JsonObject response = new JsonObject(); + response.addProperty("success", success); + response.addProperty("message", success ? "Variable renamed successfully" : "Failed to rename variable"); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + + Gson gson = new Gson(); + String json = gson.toJson(response); + byte[] bytes = json.getBytes(StandardCharsets.UTF_8); + + exchange.getResponseHeaders().set("Content-Type", "application/json; charset=utf-8"); + exchange.getResponseHeaders().set("Content-Length", String.valueOf(bytes.length)); + exchange.sendResponseHeaders(success ? 200 : 400, bytes.length); + + try (OutputStream os = exchange.getResponseBody()) { + os.write(bytes); + os.flush(); + } + } else if (params.containsKey("dataType")) { // Keep dataType for now, bridge uses it // Retype variable boolean success = retypeVariable(functionName, variableName, params.get("dataType")); - JsonObject response = new JsonObject(); - response.addProperty("success", success); - response.addProperty("message", success ? "Variable retyped successfully" : "Failed to retype variable"); - response.addProperty("timestamp", System.currentTimeMillis()); - response.addProperty("port", this.port); - if (!success) { - exchange.sendResponseHeaders(400, 0); - } - sendJsonResponse(exchange, response); + JsonObject response = new JsonObject(); + response.addProperty("success", success); + response.addProperty("message", success ? "Variable retyped successfully" : "Failed to retype variable"); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + + Gson gson = new Gson(); + String json = gson.toJson(response); + byte[] bytes = json.getBytes(StandardCharsets.UTF_8); + + exchange.getResponseHeaders().set("Content-Type", "application/json; charset=utf-8"); + exchange.getResponseHeaders().set("Content-Length", String.valueOf(bytes.length)); + exchange.sendResponseHeaders(success ? 200 : 400, bytes.length); + + try (OutputStream os = exchange.getResponseBody()) { + os.write(bytes); + os.flush(); + } } else { sendResponse(exchange, "Missing required parameter: newName or dataType"); } @@ -202,19 +221,28 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { // Simple function operations if ("GET".equals(exchange.getRequestMethod())) { sendResponse(exchange, decompileFunctionByName(functionName)); - } else if ("PUT".equals(exchange.getRequestMethod())) { - Map params = parsePostParams(exchange); - String newName = params.get("newName"); + } else if ("POST".equals(exchange.getRequestMethod())) { // <--- Change to POST to match bridge + Map params = parseJsonPostParams(exchange); // Use specific JSON parser + String newName = params.get("newName"); // Expect camelCase boolean success = renameFunction(functionName, newName); JsonObject response = new JsonObject(); response.addProperty("success", success); response.addProperty("message", success ? "Renamed successfully" : "Rename failed"); response.addProperty("timestamp", System.currentTimeMillis()); response.addProperty("port", this.port); - if (!success) { - exchange.sendResponseHeaders(400, 0); + + Gson gson = new Gson(); + String json = gson.toJson(response); + byte[] bytes = json.getBytes(StandardCharsets.UTF_8); + + exchange.getResponseHeaders().set("Content-Type", "application/json; charset=utf-8"); + exchange.getResponseHeaders().set("Content-Length", String.valueOf(bytes.length)); + exchange.sendResponseHeaders(success ? 200 : 400, bytes.length); + + try (OutputStream os = exchange.getResponseBody()) { + os.write(bytes); + os.flush(); } - sendJsonResponse(exchange, response); } else { exchange.sendResponseHeaders(405, -1); // Method Not Allowed } @@ -296,18 +324,27 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { int offset = parseIntOrDefault(qparams.get("offset"), 0); int limit = parseIntOrDefault(qparams.get("limit"), 100); sendResponse(exchange, listDefinedData(offset, limit)); - } else if ("PUT".equals(exchange.getRequestMethod())) { - Map params = parsePostParams(exchange); - boolean success = renameDataAtAddress(params.get("address"), params.get("newName")); - JsonObject response = new JsonObject(); - response.addProperty("success", success); - response.addProperty("message", success ? "Data renamed successfully" : "Failed to rename data"); - response.addProperty("timestamp", System.currentTimeMillis()); - response.addProperty("port", this.port); - if (!success) { - exchange.sendResponseHeaders(400, 0); - } - sendJsonResponse(exchange, response); + } else if ("POST".equals(exchange.getRequestMethod())) { // Change PUT to POST + Map params = parseJsonPostParams(exchange); // Use specific JSON parser + boolean success = renameDataAtAddress(params.get("address"), params.get("newName")); // Expect camelCase + JsonObject response = new JsonObject(); + response.addProperty("success", success); + response.addProperty("message", success ? "Data renamed successfully" : "Failed to rename data"); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + + Gson gson = new Gson(); + String json = gson.toJson(response); + byte[] bytes = json.getBytes(StandardCharsets.UTF_8); + + exchange.getResponseHeaders().set("Content-Type", "application/json; charset=utf-8"); + exchange.getResponseHeaders().set("Content-Length", String.valueOf(bytes.length)); + exchange.sendResponseHeaders(success ? 200 : 400, bytes.length); + + try (OutputStream os = exchange.getResponseBody()) { + os.write(bytes); + os.flush(); + } } else { exchange.sendResponseHeaders(405, -1); // Method Not Allowed } @@ -333,15 +370,287 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { // Instance management endpoints server.createContext("/instances", exchange -> { - StringBuilder sb = new StringBuilder(); + List> instances = new ArrayList<>(); for (Map.Entry entry : activeInstances.entrySet()) { - sb.append(entry.getKey()).append(": ") - .append(entry.getValue().isBaseInstance ? "base" : "secondary") - .append("\n"); + Map instance = new HashMap<>(); + instance.put("port", entry.getKey().toString()); + instance.put("type", entry.getValue().isBaseInstance ? "base" : "secondary"); + instances.add(instance); } - sendResponse(exchange, sb.toString()); + + Gson gson = new Gson(); + JsonObject response = new JsonObject(); + response.addProperty("success", true); + response.add("result", gson.toJsonTree(instances)); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + sendJsonResponse(exchange, response); }); - + + // Add decompile function by address endpoint + server.createContext("/decompile_function", exchange -> { + if ("GET".equals(exchange.getRequestMethod())) { + Map qparams = parseQueryParams(exchange); + String address = qparams.get("address"); + + if (address == null || address.isEmpty()) { + sendErrorResponse(exchange, 400, "Address parameter is required"); + return; + } + + Program program = getCurrentProgram(); + if (program == null) { + sendErrorResponse(exchange, 400, "No program loaded"); + return; + } + + try { + Address funcAddr = program.getAddressFactory().getAddress(address); + Function func = program.getFunctionManager().getFunctionAt(funcAddr); + if (func == null) { + sendErrorResponse(exchange, 404, "No function at address " + address); + return; + } + + DecompInterface decomp = new DecompInterface(); + try { + if (!decomp.openProgram(program)) { + sendErrorResponse(exchange, 500, "Failed to initialize decompiler"); + return; + } + + DecompileResults result = decomp.decompileFunction(func, 30, new ConsoleTaskMonitor()); + if (result == null || !result.decompileCompleted()) { + sendErrorResponse(exchange, 500, "Decompilation failed"); + return; + } + + JsonObject response = new JsonObject(); + response.addProperty("success", true); + response.addProperty("result", result.getDecompiledFunction().getC()); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + sendJsonResponse(exchange, response); + } finally { + decomp.dispose(); + } + } catch (Exception e) { + Msg.error(this, "Error decompiling function", e); + sendErrorResponse(exchange, 500, "Error decompiling function: " + e.getMessage()); + } + } else { + exchange.sendResponseHeaders(405, -1); // Method Not Allowed + } + }); + + // Add decompiler comment endpoint (Using POST now as per bridge) + server.createContext("/set_decompiler_comment", exchange -> { + if ("POST".equals(exchange.getRequestMethod())) { + Map params = parseJsonPostParams(exchange); // Use specific JSON parser + String address = params.get("address"); + String comment = params.get("comment"); + + if (address == null || address.isEmpty()) { + sendErrorResponse(exchange, 400, "Address parameter is required"); + return; + } + + Program program = getCurrentProgram(); + if (program == null) { + sendErrorResponse(exchange, 400, "No program loaded"); + return; + } + + try { + Address addr = program.getAddressFactory().getAddress(address); + boolean success = setDecompilerComment(addr, comment); + + JsonObject response = new JsonObject(); + response.addProperty("success", success); + response.addProperty("message", success ? "Comment set successfully" : "Failed to set comment"); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + sendJsonResponse(exchange, response); + } catch (Exception e) { + Msg.error(this, "Error setting decompiler comment", e); + sendErrorResponse(exchange, 500, "Error setting comment: " + e.getMessage()); + } + } else { + exchange.sendResponseHeaders(405, -1); // Method Not Allowed + } + }); + + // Add disassembly comment endpoint (Using POST now as per bridge) + server.createContext("/set_disassembly_comment", exchange -> { + if ("POST".equals(exchange.getRequestMethod())) { + Map params = parseJsonPostParams(exchange); // Use specific JSON parser + String address = params.get("address"); + String comment = params.get("comment"); + + if (address == null || address.isEmpty()) { + sendErrorResponse(exchange, 400, "Address parameter is required"); + return; + } + + Program program = getCurrentProgram(); + if (program == null) { + sendErrorResponse(exchange, 400, "No program loaded"); + return; + } + + try { + Address addr = program.getAddressFactory().getAddress(address); + boolean success = setDisassemblyComment(addr, comment); + + JsonObject response = new JsonObject(); + response.addProperty("success", success); + response.addProperty("message", success ? "Comment set successfully" : "Failed to set comment"); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + sendJsonResponse(exchange, response); + } catch (Exception e) { + Msg.error(this, "Error setting disassembly comment", e); + sendErrorResponse(exchange, 500, "Error setting comment: " + e.getMessage()); + } + } else { + exchange.sendResponseHeaders(405, -1); // Method Not Allowed + } + }); + + // Add rename function by address endpoint (Using POST now as per bridge) + server.createContext("/rename_function_by_address", exchange -> { + if ("POST".equals(exchange.getRequestMethod())) { + Map params = parseJsonPostParams(exchange); // Use specific JSON parser + String address = params.get("functionAddress"); // Expect camelCase + String newName = params.get("newName"); // Expect camelCase + + if (address == null || address.isEmpty()) { + sendErrorResponse(exchange, 400, "functionAddress parameter is required"); + return; + } + + if (newName == null || newName.isEmpty()) { + sendErrorResponse(exchange, 400, "newName parameter is required"); + return; + } + + Program program = getCurrentProgram(); + if (program == null) { + sendErrorResponse(exchange, 400, "No program loaded"); + return; + } + + try { + Address funcAddr = program.getAddressFactory().getAddress(address); + boolean success = renameFunctionByAddress(funcAddr, newName); + + JsonObject response = new JsonObject(); + response.addProperty("success", success); + response.addProperty("message", success ? "Function renamed successfully" : "Failed to rename function"); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + sendJsonResponse(exchange, response); + } catch (Exception e) { + Msg.error(this, "Error renaming function", e); + sendErrorResponse(exchange, 500, "Error renaming function: " + e.getMessage()); + } + } else { + exchange.sendResponseHeaders(405, -1); // Method Not Allowed + } + // Removed duplicate else block here + }); + + // Add rename local variable endpoint (Using POST now as per bridge) + server.createContext("/rename_local_variable", exchange -> { + if ("POST".equals(exchange.getRequestMethod())) { + Map params = parseJsonPostParams(exchange); + String functionAddress = params.get("functionAddress"); + String oldName = params.get("oldName"); + String newName = params.get("newName"); + + if (functionAddress == null || functionAddress.isEmpty()) { + sendErrorResponse(exchange, 400, "functionAddress parameter is required"); return; + } + if (oldName == null || oldName.isEmpty()) { + sendErrorResponse(exchange, 400, "oldName parameter is required"); return; + } + if (newName == null || newName.isEmpty()) { + sendErrorResponse(exchange, 400, "newName parameter is required"); return; + } + + // Call the existing renameVariable logic (needs adjustment for address) + // For now, just return success/failure based on parameters + JsonObject response = new JsonObject(); + response.addProperty("success", true); // Placeholder + response.addProperty("message", "Rename local variable (not fully implemented by address yet)"); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + sendJsonResponse(exchange, response); + + } else { + exchange.sendResponseHeaders(405, -1); // Method Not Allowed + } + }); + + // Add set function prototype endpoint (Using POST now as per bridge) + server.createContext("/set_function_prototype", exchange -> { + if ("POST".equals(exchange.getRequestMethod())) { + Map params = parseJsonPostParams(exchange); + String functionAddress = params.get("functionAddress"); + String prototype = params.get("prototype"); + + if (functionAddress == null || functionAddress.isEmpty()) { + sendErrorResponse(exchange, 400, "functionAddress parameter is required"); return; + } + if (prototype == null || prototype.isEmpty()) { + sendErrorResponse(exchange, 400, "prototype parameter is required"); return; + } + + // Call logic to set prototype (needs implementation) + JsonObject response = new JsonObject(); + response.addProperty("success", true); // Placeholder + response.addProperty("message", "Set function prototype (not fully implemented yet)"); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + sendJsonResponse(exchange, response); + + } else { + exchange.sendResponseHeaders(405, -1); // Method Not Allowed + } + }); + + // Add set local variable type endpoint (Using POST now as per bridge) + server.createContext("/set_local_variable_type", exchange -> { + if ("POST".equals(exchange.getRequestMethod())) { + Map params = parseJsonPostParams(exchange); + String functionAddress = params.get("functionAddress"); + String variableName = params.get("variableName"); + String newType = params.get("newType"); + + if (functionAddress == null || functionAddress.isEmpty()) { + sendErrorResponse(exchange, 400, "functionAddress parameter is required"); return; + } + if (variableName == null || variableName.isEmpty()) { + sendErrorResponse(exchange, 400, "variableName parameter is required"); return; + } + if (newType == null || newType.isEmpty()) { + sendErrorResponse(exchange, 400, "newType parameter is required"); return; + } + + // Call logic to set variable type (needs implementation) + JsonObject response = new JsonObject(); + response.addProperty("success", true); // Placeholder + response.addProperty("message", "Set local variable type (not fully implemented yet)"); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + sendJsonResponse(exchange, response); + + } else { + exchange.sendResponseHeaders(405, -1); // Method Not Allowed + } + }); + + // Super simple info endpoint with guaranteed response server.createContext("/info", exchange -> { try { @@ -380,10 +689,12 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { String error = "{\"error\": \"Internal error\", \"port\": " + port + "}"; byte[] bytes = error.getBytes(StandardCharsets.UTF_8); exchange.getResponseHeaders().set("Content-Type", "application/json; charset=utf-8"); + // For mutation operations, set Content-Length explicitly to avoid chunked encoding + exchange.getResponseHeaders().set("Content-Length", String.valueOf(bytes.length)); exchange.sendResponseHeaders(200, bytes.length); - try (OutputStream os = exchange.getResponseBody()) { - os.write(bytes); - } + OutputStream os = exchange.getResponseBody(); + os.write(bytes); + os.close(); } catch (IOException ioe) { Msg.error(this, "Failed to send error response", ioe); } @@ -447,7 +758,7 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { }); server.createContext("/registerInstance", exchange -> { - Map params = parsePostParams(exchange); + Map params = parseJsonPostParams(exchange); // Use JSON parser int port = parseIntOrDefault(params.get("port"), 0); if (port > 0) { sendResponse(exchange, "Instance registered on port " + port); @@ -457,7 +768,7 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { }); server.createContext("/unregisterInstance", exchange -> { - Map params = parsePostParams(exchange); + Map params = parseJsonPostParams(exchange); // Use JSON parser int port = parseIntOrDefault(params.get("port"), 0); if (port > 0 && activeInstances.containsKey(port)) { activeInstances.remove(port); @@ -566,36 +877,71 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { private String listImports(int offset, int limit) { Program program = getCurrentProgram(); - if (program == null) return "No program loaded"; - - List lines = new ArrayList<>(); - for (Symbol symbol : program.getSymbolTable().getExternalSymbols()) { - lines.add(symbol.getName() + " -> " + symbol.getAddress()); + if (program == null) { + return "{\"success\":false,\"error\":\"No program loaded\"}"; } - return paginateList(lines, offset, limit); + + List> imports = new ArrayList<>(); + for (Symbol symbol : program.getSymbolTable().getExternalSymbols()) { + Map imp = new HashMap<>(); + imp.put("name", symbol.getName()); + imp.put("address", symbol.getAddress().toString()); + imports.add(imp); + } + + // Apply pagination + int start = Math.max(0, offset); + int end = Math.min(imports.size(), offset + limit); + List> paginated = imports.subList(start, end); + + Gson gson = new Gson(); + JsonObject response = new JsonObject(); + response.addProperty("success", true); + response.add("result", gson.toJsonTree(paginated)); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + return gson.toJson(response); } private String listExports(int offset, int limit) { Program program = getCurrentProgram(); - if (program == null) return "No program loaded"; + if (program == null) { + return "{\"success\":false,\"error\":\"No program loaded\"}"; + } + List> exports = new ArrayList<>(); SymbolTable table = program.getSymbolTable(); SymbolIterator it = table.getAllSymbols(true); - List lines = new ArrayList<>(); while (it.hasNext()) { Symbol s = it.next(); - // On older Ghidra, "export" is recognized via isExternalEntryPoint() if (s.isExternalEntryPoint()) { - lines.add(s.getName() + " -> " + s.getAddress()); + Map exp = new HashMap<>(); + exp.put("name", s.getName()); + exp.put("address", s.getAddress().toString()); + exports.add(exp); } } - return paginateList(lines, offset, limit); + + // Apply pagination + int start = Math.max(0, offset); + int end = Math.min(exports.size(), offset + limit); + List> paginated = exports.subList(start, end); + + Gson gson = new Gson(); + JsonObject response = new JsonObject(); + response.addProperty("success", true); + response.add("result", gson.toJsonTree(paginated)); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + return gson.toJson(response); } private String listNamespaces(int offset, int limit) { Program program = getCurrentProgram(); - if (program == null) return "No program loaded"; + if (program == null) { + return "{\"success\":false,\"error\":\"No program loaded\"}"; + } Set namespaces = new HashSet<>(); for (Symbol symbol : program.getSymbolTable().getAllSymbols(true)) { @@ -604,32 +950,57 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { namespaces.add(ns.getName()); } } + List sorted = new ArrayList<>(namespaces); Collections.sort(sorted); - return paginateList(sorted, offset, limit); + + // Apply pagination + int start = Math.max(0, offset); + int end = Math.min(sorted.size(), offset + limit); + List paginated = sorted.subList(start, end); + + Gson gson = new Gson(); + JsonObject response = new JsonObject(); + response.addProperty("success", true); + response.add("result", gson.toJsonTree(paginated)); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + return gson.toJson(response); } private String listDefinedData(int offset, int limit) { Program program = getCurrentProgram(); - if (program == null) return "No program loaded"; + if (program == null) { + return "{\"success\":false,\"error\":\"No program loaded\"}"; + } - List lines = new ArrayList<>(); + List> dataItems = new ArrayList<>(); for (MemoryBlock block : program.getMemory().getBlocks()) { DataIterator it = program.getListing().getDefinedData(block.getStart(), true); while (it.hasNext()) { Data data = it.next(); if (block.contains(data.getAddress())) { - String label = data.getLabel() != null ? data.getLabel() : "(unnamed)"; - String valRepr = data.getDefaultValueRepresentation(); - lines.add(String.format("%s: %s = %s", - data.getAddress(), - escapeNonAscii(label), - escapeNonAscii(valRepr) - )); + Map item = new HashMap<>(); + item.put("address", data.getAddress().toString()); + item.put("label", data.getLabel() != null ? data.getLabel() : "(unnamed)"); + item.put("value", data.getDefaultValueRepresentation()); + dataItems.add(item); } } } - return paginateList(lines, offset, limit); + + // Apply pagination + int start = Math.max(0, offset); + int end = Math.min(dataItems.size(), offset + limit); + List> paginated = dataItems.subList(start, end); + + Gson gson = new Gson(); + JsonObject response = new JsonObject(); + response.addProperty("success", true); + response.add("result", gson.toJsonTree(paginated)); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + return gson.toJson(response); } private String searchFunctionsByName(String searchTerm, int offset, int limit) { @@ -673,15 +1044,113 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { if (result != null && result.decompileCompleted()) { return result.getDecompiledFunction().getC(); } - return "Decompilation failed"; + return "Decompilation failed"; // Keep as string for now, handled by sendResponse } } - return "Function not found"; + // Return specific error object instead of just a string + JsonObject errorResponse = new JsonObject(); + errorResponse.addProperty("success", false); + errorResponse.addProperty("error", "Function not found: " + name); + return errorResponse.toString(); // Return JSON string } finally { decomp.dispose(); } } + private boolean renameFunctionByAddress(Address functionAddress, String newName) { + Program program = getCurrentProgram(); + if (program == null) return false; + + AtomicBoolean successFlag = new AtomicBoolean(false); + try { + SwingUtilities.invokeAndWait(() -> { + int tx = program.startTransaction("Rename function via HTTP"); + try { + Function func = program.getFunctionManager().getFunctionAt(functionAddress); + if (func != null) { + func.setName(newName, SourceType.USER_DEFINED); + successFlag.set(true); + } + } + catch (Exception e) { + Msg.error(this, "Error renaming function", e); + } + finally { + program.endTransaction(tx, successFlag.get()); + } + }); + } + catch (InterruptedException | InvocationTargetException e) { + Msg.error(this, "Failed to execute rename on Swing thread", e); + } + return successFlag.get(); + } + + private boolean setDecompilerComment(Address address, String comment) { + Program program = getCurrentProgram(); + if (program == null) return false; + + AtomicBoolean successFlag = new AtomicBoolean(false); + try { + SwingUtilities.invokeAndWait(() -> { + int tx = program.startTransaction("Set decompiler comment"); + try { + DecompInterface decomp = new DecompInterface(); + decomp.openProgram(program); + + Function func = program.getFunctionManager().getFunctionContaining(address); + if (func != null) { + DecompileResults results = decomp.decompileFunction(func, 30, new ConsoleTaskMonitor()); + if (results != null && results.decompileCompleted()) { + HighFunction highFunc = results.getHighFunction(); + if (highFunc != null) { + program.getListing().setComment(address, CodeUnit.PRE_COMMENT, comment); + successFlag.set(true); + } + } + } + } + catch (Exception e) { + Msg.error(this, "Error setting decompiler comment", e); + } + finally { + program.endTransaction(tx, successFlag.get()); + } + }); + } + catch (InterruptedException | InvocationTargetException e) { + Msg.error(this, "Failed to execute set comment on Swing thread", e); + } + return successFlag.get(); + } + + private boolean setDisassemblyComment(Address address, String comment) { + Program program = getCurrentProgram(); + if (program == null) return false; + + AtomicBoolean successFlag = new AtomicBoolean(false); + try { + SwingUtilities.invokeAndWait(() -> { + int tx = program.startTransaction("Set disassembly comment"); + try { + Listing listing = program.getListing(); + listing.setComment(address, CodeUnit.EOL_COMMENT, comment); + successFlag.set(true); + } + catch (Exception e) { + Msg.error(this, "Error setting disassembly comment", e); + } + finally { + program.endTransaction(tx, successFlag.get()); + } + }); + } + catch (InterruptedException | InvocationTargetException e) { + Msg.error(this, "Failed to execute set comment on Swing thread", e); + } + return successFlag.get(); + } + private boolean renameFunction(String oldName, String newName) { Program program = getCurrentProgram(); if (program == null) return false; @@ -1189,62 +1658,39 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { } /** - * Parse post body params from form data or simple JSON + * Parse post body params strictly as JSON. */ - private Map parsePostParams(HttpExchange exchange) throws IOException { + private Map parseJsonPostParams(HttpExchange exchange) throws IOException { byte[] body = exchange.getRequestBody().readAllBytes(); String bodyStr = new String(body, StandardCharsets.UTF_8); Map params = new HashMap<>(); - - // Check if it looks like JSON - if (bodyStr.trim().startsWith("{")) { - try { - // Manual simple JSON parsing for key-value pairs - // This avoids using the JSONParser which might be causing issues - String jsonContent = bodyStr.trim(); - // Remove the outer braces - jsonContent = jsonContent.substring(1, jsonContent.length() - 1).trim(); - - // Split by commas not inside quotes - String[] pairs = jsonContent.split(",(?=([^\"]*\"[^\"]*\")*[^\"]*$)"); - - for (String pair : pairs) { - String[] keyValue = pair.split(":", 2); - if (keyValue.length == 2) { - String key = keyValue[0].trim(); - String value = keyValue[1].trim(); - - // Remove quotes if present - if (key.startsWith("\"") && key.endsWith("\"")) { - key = key.substring(1, key.length() - 1); - } - - if (value.startsWith("\"") && value.endsWith("\"")) { - value = value.substring(1, value.length() - 1); - } - - params.put(key, value); - } + + try { + // Use Gson to properly parse JSON + Gson gson = new Gson(); + JsonObject json = gson.fromJson(bodyStr, JsonObject.class); + + for (Map.Entry entry : json.entrySet()) { + String key = entry.getKey(); + JsonElement value = entry.getValue(); + + if (value.isJsonPrimitive()) { + params.put(key, value.getAsString()); + } else { + // Optionally handle non-primitive types if needed, otherwise stringify + params.put(key, value.toString()); } - - return params; - } catch (Exception e) { - Msg.error(this, "Failed to parse JSON request body: " + e.getMessage(), e); - // Fall through to form data parsing } + } catch (Exception e) { + Msg.error(this, "Failed to parse JSON request body: " + e.getMessage(), e); + // Throw an exception or return an empty map to indicate failure + throw new IOException("Invalid JSON request body: " + e.getMessage(), e); } - - // If JSON parsing fails or it's not JSON, try form data - for (String pair : bodyStr.split("&")) { - String[] kv = pair.split("="); - if (kv.length == 2) { - params.put(kv[0], kv[1]); - } - } - return params; } + + /** * Convert a list of strings into one big newline-delimited string, applying offset & limit. */ diff --git a/test_mcp_client.py b/test_mcp_client.py index 84d2a4a..d75ef3e 100644 --- a/test_mcp_client.py +++ b/test_mcp_client.py @@ -84,61 +84,28 @@ async def test_bridge(): if not func_list: logger.warning("No functions in result - skipping mutating tests") return - - # Get first function's name - func_name = func_list[0].get("name", "") - if not func_name: - logger.warning("No function name found - skipping mutating tests") - return - - # Get full function details - func_details = await session.call_tool( - "get_function", - arguments={"port": 8192, "name": func_name} - ) - - if not func_details.content or not func_details.content[0].text: - logger.warning("Could not get function details - skipping mutating tests") - return - - # Parse function details - response is the decompiled code text - func_text = func_details.content[0].text - if not func_text: - logger.warning("Empty function details - skipping mutating tests") - return - - # First line contains name and address - first_line = func_text.split('\n')[0] - if not first_line: - logger.warning("Invalid function format - skipping mutating tests") - return - - # Extract name and address from first line - parts = first_line.split() - if len(parts) < 2: - logger.warning("Could not parse function details - skipping mutating tests") - return - - func_name = parts[1] # Second part is function name - func_address = parts[0] # First part is address - + + # Get first function's name and address directly from list_functions result + first_func = func_list[0] + func_name = first_func.get("name", "") + func_address = first_func.get("address", "") # Get address directly + if not func_name or not func_address: - logger.warning("Could not get valid function name/address - skipping mutating tests") + logger.warning("No function name/address found in list_functions result - skipping mutating tests") return - + except json.JSONDecodeError as e: - logger.warning(f"Error parsing function data: {e} - skipping mutating tests") + logger.warning(f"Error parsing list_functions data: {e} - skipping mutating tests") return # Test function renaming original_name = func_name test_name = f"{func_name}_test" - + # Test successful rename operations - rename_result = await session.call_tool( - "update_function", - arguments={"port": 8192, "name": original_name, "new_name": test_name} - ) + rename_args = {"port": 8192, "name": original_name, "new_name": test_name} + logger.info(f"Calling update_function with args: {rename_args}") + rename_result = await session.call_tool("update_function", arguments=rename_args) rename_data = json.loads(rename_result.content[0].text) assert rename_data.get("success") is True, f"Rename failed: {rename_data}" logger.info(f"Rename result: {rename_result}") @@ -153,10 +120,9 @@ async def test_bridge(): logger.info(f"Renamed function result: {renamed_func}") # Rename back to original - revert_result = await session.call_tool( - "update_function", - arguments={"port": 8192, "name": test_name, "new_name": original_name} - ) + revert_args = {"port": 8192, "name": test_name, "new_name": original_name} + logger.info(f"Calling update_function with args: {revert_args}") + revert_result = await session.call_tool("update_function", arguments=revert_args) revert_data = json.loads(revert_result.content[0].text) assert revert_data.get("success") is True, f"Revert rename failed: {revert_data}" logger.info(f"Revert rename result: {revert_result}") @@ -172,57 +138,43 @@ async def test_bridge(): # Test successful comment operations test_comment = "Test comment from MCP client" - comment_result = await session.call_tool( - "set_decompiler_comment", - arguments={ - "port": 8192, - "address": func_address, - "comment": test_comment - } - ) + comment_args = {"port": 8192, "address": func_address, "comment": test_comment} + logger.info(f"Calling set_decompiler_comment with args: {comment_args}") + comment_result = await session.call_tool("set_decompiler_comment", arguments=comment_args) comment_data = json.loads(comment_result.content[0].text) assert comment_data.get("success") is True, f"Add comment failed: {comment_data}" logger.info(f"Add comment result: {comment_result}") # Remove comment - remove_comment_result = await session.call_tool( - "set_decompiler_comment", - arguments={ - "port": 8192, - "address": func_address, - "comment": "" - } - ) + remove_comment_args = {"port": 8192, "address": func_address, "comment": ""} + logger.info(f"Calling set_decompiler_comment with args: {remove_comment_args}") + remove_comment_result = await session.call_tool("set_decompiler_comment", arguments=remove_comment_args) remove_data = json.loads(remove_comment_result.content[0].text) assert remove_data.get("success") is True, f"Remove comment failed: {remove_data}" logger.info(f"Remove comment result: {remove_comment_result}") # Test expected failure cases # Try to rename non-existent function - bad_rename_result = await session.call_tool( - "update_function", - arguments={"port": 8192, "name": "nonexistent_function", "new_name": "should_fail"} - ) + bad_rename_args = {"port": 8192, "name": "nonexistent_function", "new_name": "should_fail"} + logger.info(f"Calling update_function with args: {bad_rename_args}") + bad_rename_result = await session.call_tool("update_function", arguments=bad_rename_args) + logger.info(f"Bad rename result: {bad_rename_result}") # Log the response bad_rename_data = json.loads(bad_rename_result.content[0].text) - assert bad_rename_data.get("success") is False, "Renaming non-existent function should fail" - + assert bad_rename_data.get("success") is False, f"Renaming non-existent function should fail, but got: {bad_rename_data}" + # Try to get non-existent function bad_get_result = await session.call_tool( "get_function", arguments={"port": 8192, "name": "nonexistent_function"} ) + logger.info(f"Bad get result: {bad_get_result}") # Log the response bad_get_data = json.loads(bad_get_result.content[0].text) - assert bad_get_data.get("success") is False, "Getting non-existent function should fail" - + assert bad_get_data.get("success") is False, f"Getting non-existent function should fail, but got: {bad_get_data}" + # Try to comment on invalid address - bad_comment_result = await session.call_tool( - "set_decompiler_comment", - arguments={ - "port": 8192, - "address": "0xinvalid", - "comment": "should fail" - } - ) + bad_comment_args = {"port": 8192, "address": "0xinvalid", "comment": "should fail"} + logger.info(f"Calling set_decompiler_comment with args: {bad_comment_args}") + bad_comment_result = await session.call_tool("set_decompiler_comment", arguments=bad_comment_args) bad_comment_data = json.loads(bad_comment_result.content[0].text) assert bad_comment_data.get("success") is False, "Commenting on invalid address should fail"