From 14eae14f63d94b2e168760d743ecc440e8c4d132 Mon Sep 17 00:00:00 2001 From: Teal Bauer Date: Fri, 4 Apr 2025 18:10:45 +0200 Subject: [PATCH] Switch all results over to JSON --- .../eu/starsong/ghidra/GhydraMCPPlugin.java | 368 ++++++++++-------- test_http_api.py | 37 +- test_mcp_client.py | 127 ++++-- 3 files changed, 335 insertions(+), 197 deletions(-) diff --git a/src/main/java/eu/starsong/ghidra/GhydraMCPPlugin.java b/src/main/java/eu/starsong/ghidra/GhydraMCPPlugin.java index 5c2c19d..e11e8d1 100644 --- a/src/main/java/eu/starsong/ghidra/GhydraMCPPlugin.java +++ b/src/main/java/eu/starsong/ghidra/GhydraMCPPlugin.java @@ -1,70 +1,66 @@ package eu.starsong.ghidra; -import ghidra.framework.plugintool.*; -import ghidra.framework.main.ApplicationLevelPlugin; -import ghidra.program.model.address.Address; -import ghidra.program.model.address.GlobalNamespace; -import ghidra.program.model.data.DataType; -import ghidra.program.model.data.DataTypeManager; -import ghidra.program.model.listing.*; -import ghidra.program.model.mem.MemoryBlock; -import ghidra.program.model.pcode.HighVariable; -import ghidra.program.model.pcode.HighSymbol; -import ghidra.program.model.pcode.VarnodeAST; -import ghidra.program.model.pcode.HighFunction; -import ghidra.program.model.pcode.HighFunctionDBUtil; -import ghidra.program.model.pcode.LocalSymbolMap; -import ghidra.program.model.pcode.HighFunctionDBUtil.ReturnCommitOption; -import ghidra.program.model.symbol.*; -import ghidra.app.decompiler.DecompInterface; -import ghidra.app.decompiler.DecompileResults; -import ghidra.app.decompiler.ClangNode; -import ghidra.app.decompiler.ClangTokenGroup; -import ghidra.app.decompiler.ClangVariableToken; -import ghidra.app.plugin.PluginCategoryNames; -import ghidra.app.services.ProgramManager; -import ghidra.app.util.demangler.DemanglerUtil; -import ghidra.framework.model.Project; -import ghidra.framework.model.DomainFile; -import ghidra.framework.plugintool.PluginInfo; -import ghidra.framework.plugintool.util.PluginStatus; -import ghidra.util.Msg; -import ghidra.util.task.ConsoleTaskMonitor; - -import com.sun.net.httpserver.HttpExchange; -import com.sun.net.httpserver.HttpServer; - -import javax.swing.SwingUtilities; import java.io.IOException; import java.io.OutputStream; import java.lang.reflect.InvocationTargetException; import java.net.InetSocketAddress; import java.net.ServerSocket; import java.nio.charset.StandardCharsets; -import java.util.*; -import java.util.concurrent.*; -import java.util.concurrent.atomic.*; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import javax.swing.SwingUtilities; // For JSON response handling import com.google.gson.Gson; import com.google.gson.JsonObject; +import com.sun.net.httpserver.HttpExchange; +import com.sun.net.httpserver.HttpServer; -import ghidra.app.services.CodeViewerService; -import ghidra.app.util.PseudoDisassembler; -import ghidra.app.cmd.function.SetVariableNameCmd; +import ghidra.app.decompiler.DecompInterface; +import ghidra.app.decompiler.DecompileResults; +import ghidra.app.plugin.PluginCategoryNames; +import ghidra.app.services.ProgramManager; +import ghidra.framework.main.ApplicationLevelPlugin; +import ghidra.framework.model.Project; +import ghidra.framework.plugintool.Plugin; +import ghidra.framework.plugintool.PluginInfo; +import ghidra.framework.plugintool.PluginTool; +import ghidra.framework.plugintool.util.PluginStatus; +import ghidra.program.model.address.Address; +import ghidra.program.model.address.GlobalNamespace; +import ghidra.program.model.data.DataType; +import ghidra.program.model.data.DataTypeManager; +import ghidra.program.model.listing.Data; +import ghidra.program.model.listing.DataIterator; +import ghidra.program.model.listing.Function; +import ghidra.program.model.listing.Listing; +import ghidra.program.model.listing.Parameter; +import ghidra.program.model.listing.Program; +import ghidra.program.model.listing.VariableStorage; +import ghidra.program.model.mem.MemoryBlock; +import ghidra.program.model.pcode.HighFunction; +import ghidra.program.model.pcode.HighFunctionDBUtil; +import ghidra.program.model.pcode.HighFunctionDBUtil.ReturnCommitOption; +import ghidra.program.model.pcode.HighSymbol; +import ghidra.program.model.pcode.LocalSymbolMap; +import ghidra.program.model.symbol.Namespace; import ghidra.program.model.symbol.SourceType; -import ghidra.program.model.listing.LocalVariableImpl; -import ghidra.program.model.listing.ParameterImpl; -import ghidra.util.exception.DuplicateNameException; -import ghidra.util.exception.InvalidInputException; -import ghidra.program.util.ProgramLocation; -import ghidra.util.task.TaskMonitor; -import ghidra.program.model.pcode.Varnode; -import ghidra.program.model.data.PointerDataType; -import ghidra.program.model.data.Undefined1DataType; -import ghidra.program.model.listing.Variable; -import ghidra.app.decompiler.component.DecompilerUtils; -import ghidra.app.decompiler.ClangToken; +import ghidra.program.model.symbol.Symbol; +import ghidra.program.model.symbol.SymbolIterator; +import ghidra.program.model.symbol.SymbolTable; +import ghidra.program.model.symbol.SymbolType; +import ghidra.util.Msg; +import ghidra.util.task.ConsoleTaskMonitor; @PluginInfo( status = PluginStatus.RELEASED, @@ -97,8 +93,6 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { } } - Msg.info(this, "Marker"); - // Log to both console and log file Msg.info(this, "GhydraMCPPlugin loaded on port " + port); System.out.println("[GhydraMCP] Plugin loaded on port " + port); @@ -176,12 +170,28 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { Map params = parsePostParams(exchange); if (params.containsKey("newName")) { // Rename variable - String result = renameVariable(functionName, variableName, params.get("newName")); - sendResponse(exchange, result); + boolean success = renameVariable(functionName, variableName, params.get("newName")); + JsonObject response = new JsonObject(); + response.addProperty("success", success); + response.addProperty("message", success ? "Variable renamed successfully" : "Failed to rename variable"); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + if (!success) { + exchange.sendResponseHeaders(400, 0); + } + sendJsonResponse(exchange, response); } else if (params.containsKey("dataType")) { // Retype variable - String result = retypeVariable(functionName, variableName, params.get("dataType")); - sendResponse(exchange, result); + boolean success = retypeVariable(functionName, variableName, params.get("dataType")); + JsonObject response = new JsonObject(); + response.addProperty("success", success); + response.addProperty("message", success ? "Variable retyped successfully" : "Failed to retype variable"); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + if (!success) { + exchange.sendResponseHeaders(400, 0); + } + sendJsonResponse(exchange, response); } else { sendResponse(exchange, "Missing required parameter: newName or dataType"); } @@ -195,61 +205,40 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { } else if ("PUT".equals(exchange.getRequestMethod())) { Map params = parsePostParams(exchange); String newName = params.get("newName"); - String response = renameFunction(functionName, newName) - ? "Renamed successfully" : "Rename failed"; - sendResponse(exchange, response); + boolean success = renameFunction(functionName, newName); + JsonObject response = new JsonObject(); + response.addProperty("success", success); + response.addProperty("message", success ? "Renamed successfully" : "Rename failed"); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + if (!success) { + exchange.sendResponseHeaders(400, 0); + } + sendJsonResponse(exchange, response); } else { exchange.sendResponseHeaders(405, -1); // Method Not Allowed } } }); - // Class resources with detailed logging + // Class resources server.createContext("/classes", exchange -> { - try { - if ("GET".equals(exchange.getRequestMethod())) { + if ("GET".equals(exchange.getRequestMethod())) { + try { + Map qparams = parseQueryParams(exchange); + int offset = parseIntOrDefault(qparams.get("offset"), 0); + int limit = parseIntOrDefault(qparams.get("limit"), 100); + sendJsonResponse(exchange, getAllClassNames(offset, limit)); + } catch (Exception e) { + Msg.error(this, "/classes: Error in request processing", e); try { - Map qparams = parseQueryParams(exchange); - int offset = parseIntOrDefault(qparams.get("offset"), 0); - int limit = parseIntOrDefault(qparams.get("limit"), 100); - - String result = getAllClassNames(offset, limit); - - JsonObject json = new JsonObject(); - json.addProperty("success", true); - json.addProperty("result", result); - json.addProperty("timestamp", System.currentTimeMillis()); - json.addProperty("port", this.port); - - Gson gson = new Gson(); - String jsonStr = gson.toJson(json); - - exchange.getResponseHeaders().set("Content-Type", "application/json; charset=utf-8"); - - byte[] bytes = jsonStr.getBytes(StandardCharsets.UTF_8); - exchange.sendResponseHeaders(200, bytes.length); - - OutputStream os = exchange.getResponseBody(); - - os.write(bytes); - - os.flush(); - - os.close(); - - } catch (Exception e) { - Msg.error(this, "/classes: Error in request processing: " + e.getMessage(), e); - try { - sendErrorResponse(exchange, 500, "Internal server error: " + e.getMessage()); - } catch (IOException ioe) { - Msg.error(this, "/classes: Failed to send error response: " + ioe.getMessage(), ioe); - } + sendErrorResponse(exchange, 500, "Internal server error"); + } catch (IOException ioe) { + Msg.error(this, "/classes: Failed to send error response", ioe); } - } else { - exchange.sendResponseHeaders(405, -1); // Method Not Allowed } - } catch (Exception e) { - Msg.error(this, "/classes: Unhandled error: " + e.getMessage(), e); + } else { + exchange.sendResponseHeaders(405, -1); // Method Not Allowed } }); @@ -309,8 +298,16 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { sendResponse(exchange, listDefinedData(offset, limit)); } else if ("PUT".equals(exchange.getRequestMethod())) { Map params = parsePostParams(exchange); - renameDataAtAddress(params.get("address"), params.get("newName")); - sendResponse(exchange, "Rename data attempted"); + boolean success = renameDataAtAddress(params.get("address"), params.get("newName")); + JsonObject response = new JsonObject(); + response.addProperty("success", success); + response.addProperty("message", success ? "Data renamed successfully" : "Failed to rename data"); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + if (!success) { + exchange.sendResponseHeaders(400, 0); + } + sendJsonResponse(exchange, response); } else { exchange.sendResponseHeaders(405, -1); // Method Not Allowed } @@ -484,18 +481,38 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { private String getAllFunctionNames(int offset, int limit) { Program program = getCurrentProgram(); - if (program == null) return "No program loaded"; + if (program == null) return "{\"success\":false,\"error\":\"No program loaded\"}"; - List names = new ArrayList<>(); + List> functions = new ArrayList<>(); for (Function f : program.getFunctionManager().getFunctions(true)) { - names.add(f.getName() + " @ " + f.getEntryPoint()); + Map func = new HashMap<>(); + func.put("name", f.getName()); + func.put("address", f.getEntryPoint().toString()); + functions.add(func); } - return paginateList(names, offset, limit); + + // Apply pagination + int start = Math.max(0, offset); + int end = Math.min(functions.size(), offset + limit); + List> paginated = functions.subList(start, end); + + Gson gson = new Gson(); + JsonObject response = new JsonObject(); + response.addProperty("success", true); + response.add("result", gson.toJsonTree(paginated)); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + return gson.toJson(response); } - private String getAllClassNames(int offset, int limit) { + private JsonObject getAllClassNames(int offset, int limit) { Program program = getCurrentProgram(); - if (program == null) return "No program loaded"; + if (program == null) { + JsonObject error = new JsonObject(); + error.addProperty("success", false); + error.addProperty("error", "No program loaded"); + return error; + } Set classNames = new HashSet<>(); for (Symbol symbol : program.getSymbolTable().getAllSymbols(true)) { @@ -504,21 +521,47 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { classNames.add(ns.getName()); } } - // Convert set to list for pagination + + // Convert to sorted list and paginate List sorted = new ArrayList<>(classNames); Collections.sort(sorted); - return paginateList(sorted, offset, limit); + int start = Math.max(0, offset); + int end = Math.min(sorted.size(), offset + limit); + List paginated = sorted.subList(start, end); + + JsonObject response = new JsonObject(); + response.addProperty("success", true); + response.add("result", new Gson().toJsonTree(paginated)); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + return response; } private String listSegments(int offset, int limit) { Program program = getCurrentProgram(); - if (program == null) return "No program loaded"; + if (program == null) return "{\"success\":false,\"error\":\"No program loaded\"}"; - List lines = new ArrayList<>(); + List> segments = new ArrayList<>(); for (MemoryBlock block : program.getMemory().getBlocks()) { - lines.add(String.format("%s: %s - %s", block.getName(), block.getStart(), block.getEnd())); + Map seg = new HashMap<>(); + seg.put("name", block.getName()); + seg.put("start", block.getStart().toString()); + seg.put("end", block.getEnd().toString()); + segments.add(seg); } - return paginateList(lines, offset, limit); + + // Apply pagination + int start = Math.max(0, offset); + int end = Math.min(segments.size(), offset + limit); + List> paginated = segments.subList(start, end); + + Gson gson = new Gson(); + JsonObject response = new JsonObject(); + response.addProperty("success", true); + response.add("result", gson.toJsonTree(paginated)); + response.addProperty("timestamp", System.currentTimeMillis()); + response.addProperty("port", this.port); + return gson.toJson(response); } private String listImports(int offset, int limit) { @@ -670,10 +713,11 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { return successFlag.get(); } - private void renameDataAtAddress(String addressStr, String newName) { + private boolean renameDataAtAddress(String addressStr, String newName) { Program program = getCurrentProgram(); - if (program == null) return; + if (program == null) return false; + AtomicBoolean successFlag = new AtomicBoolean(false); try { SwingUtilities.invokeAndWait(() -> { int tx = program.startTransaction("Rename data"); @@ -686,8 +730,10 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { Symbol symbol = symTable.getPrimarySymbol(addr); if (symbol != null) { symbol.setName(newName, SourceType.USER_DEFINED); + successFlag.set(true); } else { symTable.createLabel(addr, newName, SourceType.USER_DEFINED); + successFlag.set(true); } } } @@ -695,13 +741,14 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { Msg.error(this, "Rename data error", e); } finally { - program.endTransaction(tx, true); + program.endTransaction(tx, successFlag.get()); } }); } catch (InterruptedException | InvocationTargetException e) { Msg.error(this, "Failed to execute rename data on Swing thread", e); } + return successFlag.get(); } // ---------------------------------------------------------------------------------- @@ -789,9 +836,9 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { } } - private String renameVariable(String functionName, String oldName, String newName) { + private boolean renameVariable(String functionName, String oldName, String newName) { Program program = getCurrentProgram(); - if (program == null) return "No program loaded"; + if (program == null) return false; DecompInterface decomp = new DecompInterface(); decomp.openProgram(program); @@ -805,22 +852,22 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { } if (func == null) { - return "Function not found"; + return false; } DecompileResults result = decomp.decompileFunction(func, 30, new ConsoleTaskMonitor()); if (result == null || !result.decompileCompleted()) { - return "Decompilation failed"; + return false; } HighFunction highFunction = result.getHighFunction(); if (highFunction == null) { - return "Decompilation failed (no high function)"; + return false; } LocalSymbolMap localSymbolMap = highFunction.getLocalSymbolMap(); if (localSymbolMap == null) { - return "Decompilation failed (no local symbol map)"; + return false; } HighSymbol highSymbol = null; @@ -833,12 +880,12 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { highSymbol = symbol; } if (symbolName.equals(newName)) { - return "Error: A variable with name '" + newName + "' already exists in this function"; + return false; } } if (highSymbol == null) { - return "Variable not found"; + return false; } boolean commitRequired = checkFullCommit(highSymbol, highFunction); @@ -871,11 +918,10 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { } }); } catch (InterruptedException | InvocationTargetException e) { - String errorMsg = "Failed to execute rename on Swing thread: " + e.getMessage(); - Msg.error(this, errorMsg, e); - return errorMsg; + Msg.error(this, "Failed to execute rename on Swing thread", e); + return false; } - return successFlag.get() ? "Variable renamed" : "Failed to rename variable"; + return successFlag.get(); } /** @@ -914,15 +960,15 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { return false; } - private String retypeVariable(String functionName, String varName, String dataTypeName) { + private boolean retypeVariable(String functionName, String varName, String dataTypeName) { if (varName == null || varName.isEmpty() || dataTypeName == null || dataTypeName.isEmpty()) { - return "Both variable name and data type are required"; + return false; } Program program = getCurrentProgram(); - if (program == null) return "No program loaded"; + if (program == null) return false; - AtomicReference result = new AtomicReference<>("Variable retype failed"); + AtomicBoolean result = new AtomicBoolean(false); try { SwingUtilities.invokeAndWait(() -> { @@ -930,7 +976,6 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { try { Function function = findFunctionByName(program, functionName); if (function == null) { - result.set("Function not found: " + functionName); return; } @@ -940,13 +985,11 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { DecompileResults decompRes = decomp.decompileFunction(function, 30, new ConsoleTaskMonitor()); if (decompRes == null || !decompRes.decompileCompleted()) { - result.set("Failed to decompile function: " + functionName); return; } HighFunction highFunction = decompRes.getHighFunction(); if (highFunction == null) { - result.set("Failed to get high function"); return; } @@ -963,14 +1006,12 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { } if (targetSymbol == null) { - result.set("Variable not found: " + varName); return; } // Find the data type by name DataType dataType = findDataType(program, dataTypeName); if (dataType == null) { - result.set("Data type not found: " + dataTypeName); return; } @@ -978,17 +1019,17 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { HighFunctionDBUtil.updateDBVariable(targetSymbol, targetSymbol.getName(), dataType, SourceType.USER_DEFINED); - result.set("Variable '" + varName + "' retyped to '" + dataTypeName + "'"); + result.set(true); } catch (Exception e) { Msg.error(this, "Error retyping variable", e); - result.set("Error: " + e.getMessage()); + result.set(false); } finally { program.endTransaction(tx, true); } }); } catch (InterruptedException | InvocationTargetException e) { Msg.error(this, "Failed to execute on Swing thread", e); - result.set("Error: " + e.getMessage()); + result.set(false); } return result.get(); @@ -1277,21 +1318,32 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin { private void sendResponse(HttpExchange exchange, Object response) throws IOException { - JsonObject json = new JsonObject(); - json.addProperty("success", true); - if (response instanceof String) { - json.addProperty("result", (String)response); + if (response instanceof String && ((String)response).startsWith("{")) { + // Already JSON formatted, send as-is + byte[] bytes = ((String)response).getBytes(StandardCharsets.UTF_8); + exchange.getResponseHeaders().set("Content-Type", "application/json; charset=utf-8"); + exchange.sendResponseHeaders(200, bytes.length); + try (OutputStream os = exchange.getResponseBody()) { + os.write(bytes); + } } else { - json.addProperty("data", response.toString()); + // Wrap in standard response format + JsonObject json = new JsonObject(); + json.addProperty("success", true); + if (response instanceof String) { + json.addProperty("result", (String)response); + } else { + json.add("result", new Gson().toJsonTree(response)); + } + json.addProperty("timestamp", System.currentTimeMillis()); + json.addProperty("port", this.port); + if (this.isBaseInstance) { + json.addProperty("instanceType", "base"); + } else { + json.addProperty("instanceType", "secondary"); + } + sendJsonResponse(exchange, json); } - json.addProperty("timestamp", System.currentTimeMillis()); - json.addProperty("port", this.port); - if (this.isBaseInstance) { - json.addProperty("instanceType", "base"); - } else { - json.addProperty("instanceType", "secondary"); - } - sendJsonResponse(exchange, json); } private void sendJsonResponse(HttpExchange exchange, JsonObject jsonObj) throws IOException { diff --git a/test_http_api.py b/test_http_api.py index ba2a7f1..d69c723 100644 --- a/test_http_api.py +++ b/test_http_api.py @@ -84,8 +84,13 @@ class GhydraMCPHttpApiTests(unittest.TestCase): self.assertIn("timestamp", data) self.assertIn("port", data) - # Check that we have either result or data - self.assertTrue("result" in data or "data" in data) + # Check result is an array of function objects + self.assertIn("result", data) + self.assertIsInstance(data["result"], list) + if data["result"]: # If there are functions + func = data["result"][0] + self.assertIn("name", func) + self.assertIn("address", func) def test_functions_with_pagination(self): """Test the /functions endpoint with pagination""" @@ -100,10 +105,19 @@ class GhydraMCPHttpApiTests(unittest.TestCase): self.assertTrue(data["success"]) self.assertIn("timestamp", data) self.assertIn("port", data) + + # Check result is an array of max 5 function objects + self.assertIn("result", data) + self.assertIsInstance(data["result"], list) + self.assertLessEqual(len(data["result"]), 5) + if data["result"]: # If there are functions + func = data["result"][0] + self.assertIn("name", func) + self.assertIn("address", func) def test_classes_endpoint(self): """Test the /classes endpoint""" - response = requests.get(f"{BASE_URL}/classes") + response = requests.get(f"{BASE_URL}/classes?offset=0&limit=10") self.assertEqual(response.status_code, 200) # Verify response is valid JSON @@ -114,10 +128,16 @@ class GhydraMCPHttpApiTests(unittest.TestCase): self.assertTrue(data["success"]) self.assertIn("timestamp", data) self.assertIn("port", data) + + # Check result is an array of class names + self.assertIn("result", data) + self.assertIsInstance(data["result"], list) + if data["result"]: # If there are classes + self.assertIsInstance(data["result"][0], str) def test_segments_endpoint(self): """Test the /segments endpoint""" - response = requests.get(f"{BASE_URL}/segments") + response = requests.get(f"{BASE_URL}/segments?offset=0&limit=10") self.assertEqual(response.status_code, 200) # Verify response is valid JSON @@ -128,6 +148,15 @@ class GhydraMCPHttpApiTests(unittest.TestCase): self.assertTrue(data["success"]) self.assertIn("timestamp", data) self.assertIn("port", data) + + # Check result is an array of segment objects + self.assertIn("result", data) + self.assertIsInstance(data["result"], list) + if data["result"]: # If there are segments + seg = data["result"][0] + self.assertIn("name", seg) + self.assertIn("start", seg) + self.assertIn("end", seg) def test_variables_endpoint(self): """Test the /variables endpoint""" diff --git a/test_mcp_client.py b/test_mcp_client.py index 43ea318..84d2a4a 100644 --- a/test_mcp_client.py +++ b/test_mcp_client.py @@ -3,13 +3,12 @@ Test script for the GhydraMCP bridge using the MCP client. This script tests the bridge by sending MCP requests and handling responses. """ -import asyncio +import json import logging import sys from typing import Any import anyio - from mcp.client.session import ClientSession from mcp.client.stdio import StdioServerParameters, stdio_client @@ -69,60 +68,79 @@ async def test_bridge(): arguments={"port": 8192, "offset": 0, "limit": 5} ) - if not hasattr(list_funcs, "result") or not hasattr(list_funcs.result, "content") or not list_funcs.result.content: + if not list_funcs or not list_funcs.content: logger.warning("No functions found - skipping mutating tests") return - # The list_functions result contains a JSON string in the text field - func_json = list_funcs.result.content[0].get("text", "") - if not func_json: + # The list_functions result contains the function data directly + if not list_funcs.content: logger.warning("No function data found - skipping mutating tests") return + # Parse the JSON response try: - # Parse the JSON to get the function list - func_data = json.loads(func_json) - func_list = func_data.get("result", "").split("\n") + func_data = json.loads(list_funcs.content[0].text) + func_list = func_data.get("result", []) if not func_list: logger.warning("No functions in result - skipping mutating tests") return - # Extract first function name (format: "name @ address") - func_name = func_list[0].split("@")[0].strip() - except (json.JSONDecodeError, AttributeError) as e: + # Get first function's name + func_name = func_list[0].get("name", "") + if not func_name: + logger.warning("No function name found - skipping mutating tests") + return + + # Get full function details + func_details = await session.call_tool( + "get_function", + arguments={"port": 8192, "name": func_name} + ) + + if not func_details.content or not func_details.content[0].text: + logger.warning("Could not get function details - skipping mutating tests") + return + + # Parse function details - response is the decompiled code text + func_text = func_details.content[0].text + if not func_text: + logger.warning("Empty function details - skipping mutating tests") + return + + # First line contains name and address + first_line = func_text.split('\n')[0] + if not first_line: + logger.warning("Invalid function format - skipping mutating tests") + return + + # Extract name and address from first line + parts = first_line.split() + if len(parts) < 2: + logger.warning("Could not parse function details - skipping mutating tests") + return + + func_name = parts[1] # Second part is function name + func_address = parts[0] # First part is address + + if not func_name or not func_address: + logger.warning("Could not get valid function name/address - skipping mutating tests") + return + + except json.JSONDecodeError as e: logger.warning(f"Error parsing function data: {e} - skipping mutating tests") return - if not func_name: - logger.warning("Could not parse function name - skipping mutating tests") - return - - # Get full function details - func_details = await session.call_tool( - "get_function", - arguments={"port": 8192, "name": func_name} - ) - - if not hasattr(func_details, "result") or not hasattr(func_details.result, "content") or not func_details.result.content: - logger.warning("Could not get function details - skipping mutating tests") - return - - func_content = func_details.result.content[0] - func_name = func_content.get("text", "").split("\n")[0] - func_address = func_content.get("address", "") - - if not func_name or not func_address: - logger.warning("Could not get valid function name/address - skipping mutating tests") - return # Test function renaming original_name = func_name test_name = f"{func_name}_test" - # Rename to test name + # Test successful rename operations rename_result = await session.call_tool( "update_function", arguments={"port": 8192, "name": original_name, "new_name": test_name} ) + rename_data = json.loads(rename_result.content[0].text) + assert rename_data.get("success") is True, f"Rename failed: {rename_data}" logger.info(f"Rename result: {rename_result}") # Verify rename @@ -130,6 +148,8 @@ async def test_bridge(): "get_function", arguments={"port": 8192, "name": test_name} ) + renamed_data = json.loads(renamed_func.content[0].text) + assert renamed_data.get("success") is True, f"Get renamed function failed: {renamed_data}" logger.info(f"Renamed function result: {renamed_func}") # Rename back to original @@ -137,6 +157,8 @@ async def test_bridge(): "update_function", arguments={"port": 8192, "name": test_name, "new_name": original_name} ) + revert_data = json.loads(revert_result.content[0].text) + assert revert_data.get("success") is True, f"Revert rename failed: {revert_data}" logger.info(f"Revert rename result: {revert_result}") # Verify revert @@ -144,9 +166,11 @@ async def test_bridge(): "get_function", arguments={"port": 8192, "name": original_name} ) + original_data = json.loads(original_func.content[0].text) + assert original_data.get("success") is True, f"Get original function failed: {original_data}" logger.info(f"Original function result: {original_func}") - # Test adding/removing comment + # Test successful comment operations test_comment = "Test comment from MCP client" comment_result = await session.call_tool( "set_decompiler_comment", @@ -156,6 +180,8 @@ async def test_bridge(): "comment": test_comment } ) + comment_data = json.loads(comment_result.content[0].text) + assert comment_data.get("success") is True, f"Add comment failed: {comment_data}" logger.info(f"Add comment result: {comment_result}") # Remove comment @@ -167,8 +193,39 @@ async def test_bridge(): "comment": "" } ) + remove_data = json.loads(remove_comment_result.content[0].text) + assert remove_data.get("success") is True, f"Remove comment failed: {remove_data}" logger.info(f"Remove comment result: {remove_comment_result}") + # Test expected failure cases + # Try to rename non-existent function + bad_rename_result = await session.call_tool( + "update_function", + arguments={"port": 8192, "name": "nonexistent_function", "new_name": "should_fail"} + ) + bad_rename_data = json.loads(bad_rename_result.content[0].text) + assert bad_rename_data.get("success") is False, "Renaming non-existent function should fail" + + # Try to get non-existent function + bad_get_result = await session.call_tool( + "get_function", + arguments={"port": 8192, "name": "nonexistent_function"} + ) + bad_get_data = json.loads(bad_get_result.content[0].text) + assert bad_get_data.get("success") is False, "Getting non-existent function should fail" + + # Try to comment on invalid address + bad_comment_result = await session.call_tool( + "set_decompiler_comment", + arguments={ + "port": 8192, + "address": "0xinvalid", + "comment": "should fail" + } + ) + bad_comment_data = json.loads(bad_comment_result.content[0].text) + assert bad_comment_data.get("success") is False, "Commenting on invalid address should fail" + except Exception as e: logger.error(f"Error testing mutating operations: {e}") raise