Switch all results over to JSON

This commit is contained in:
Teal Bauer 2025-04-04 18:10:45 +02:00
parent ba63ffeb54
commit 14eae14f63
3 changed files with 335 additions and 197 deletions

View File

@ -1,70 +1,66 @@
package eu.starsong.ghidra; package eu.starsong.ghidra;
import ghidra.framework.plugintool.*;
import ghidra.framework.main.ApplicationLevelPlugin;
import ghidra.program.model.address.Address;
import ghidra.program.model.address.GlobalNamespace;
import ghidra.program.model.data.DataType;
import ghidra.program.model.data.DataTypeManager;
import ghidra.program.model.listing.*;
import ghidra.program.model.mem.MemoryBlock;
import ghidra.program.model.pcode.HighVariable;
import ghidra.program.model.pcode.HighSymbol;
import ghidra.program.model.pcode.VarnodeAST;
import ghidra.program.model.pcode.HighFunction;
import ghidra.program.model.pcode.HighFunctionDBUtil;
import ghidra.program.model.pcode.LocalSymbolMap;
import ghidra.program.model.pcode.HighFunctionDBUtil.ReturnCommitOption;
import ghidra.program.model.symbol.*;
import ghidra.app.decompiler.DecompInterface;
import ghidra.app.decompiler.DecompileResults;
import ghidra.app.decompiler.ClangNode;
import ghidra.app.decompiler.ClangTokenGroup;
import ghidra.app.decompiler.ClangVariableToken;
import ghidra.app.plugin.PluginCategoryNames;
import ghidra.app.services.ProgramManager;
import ghidra.app.util.demangler.DemanglerUtil;
import ghidra.framework.model.Project;
import ghidra.framework.model.DomainFile;
import ghidra.framework.plugintool.PluginInfo;
import ghidra.framework.plugintool.util.PluginStatus;
import ghidra.util.Msg;
import ghidra.util.task.ConsoleTaskMonitor;
import com.sun.net.httpserver.HttpExchange;
import com.sun.net.httpserver.HttpServer;
import javax.swing.SwingUtilities;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream; import java.io.OutputStream;
import java.lang.reflect.InvocationTargetException; import java.lang.reflect.InvocationTargetException;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.ServerSocket; import java.net.ServerSocket;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.util.*; import java.util.ArrayList;
import java.util.concurrent.*; import java.util.Collections;
import java.util.concurrent.atomic.*; import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import javax.swing.SwingUtilities;
// For JSON response handling // For JSON response handling
import com.google.gson.Gson; import com.google.gson.Gson;
import com.google.gson.JsonObject; import com.google.gson.JsonObject;
import com.sun.net.httpserver.HttpExchange;
import com.sun.net.httpserver.HttpServer;
import ghidra.app.services.CodeViewerService; import ghidra.app.decompiler.DecompInterface;
import ghidra.app.util.PseudoDisassembler; import ghidra.app.decompiler.DecompileResults;
import ghidra.app.cmd.function.SetVariableNameCmd; import ghidra.app.plugin.PluginCategoryNames;
import ghidra.app.services.ProgramManager;
import ghidra.framework.main.ApplicationLevelPlugin;
import ghidra.framework.model.Project;
import ghidra.framework.plugintool.Plugin;
import ghidra.framework.plugintool.PluginInfo;
import ghidra.framework.plugintool.PluginTool;
import ghidra.framework.plugintool.util.PluginStatus;
import ghidra.program.model.address.Address;
import ghidra.program.model.address.GlobalNamespace;
import ghidra.program.model.data.DataType;
import ghidra.program.model.data.DataTypeManager;
import ghidra.program.model.listing.Data;
import ghidra.program.model.listing.DataIterator;
import ghidra.program.model.listing.Function;
import ghidra.program.model.listing.Listing;
import ghidra.program.model.listing.Parameter;
import ghidra.program.model.listing.Program;
import ghidra.program.model.listing.VariableStorage;
import ghidra.program.model.mem.MemoryBlock;
import ghidra.program.model.pcode.HighFunction;
import ghidra.program.model.pcode.HighFunctionDBUtil;
import ghidra.program.model.pcode.HighFunctionDBUtil.ReturnCommitOption;
import ghidra.program.model.pcode.HighSymbol;
import ghidra.program.model.pcode.LocalSymbolMap;
import ghidra.program.model.symbol.Namespace;
import ghidra.program.model.symbol.SourceType; import ghidra.program.model.symbol.SourceType;
import ghidra.program.model.listing.LocalVariableImpl; import ghidra.program.model.symbol.Symbol;
import ghidra.program.model.listing.ParameterImpl; import ghidra.program.model.symbol.SymbolIterator;
import ghidra.util.exception.DuplicateNameException; import ghidra.program.model.symbol.SymbolTable;
import ghidra.util.exception.InvalidInputException; import ghidra.program.model.symbol.SymbolType;
import ghidra.program.util.ProgramLocation; import ghidra.util.Msg;
import ghidra.util.task.TaskMonitor; import ghidra.util.task.ConsoleTaskMonitor;
import ghidra.program.model.pcode.Varnode;
import ghidra.program.model.data.PointerDataType;
import ghidra.program.model.data.Undefined1DataType;
import ghidra.program.model.listing.Variable;
import ghidra.app.decompiler.component.DecompilerUtils;
import ghidra.app.decompiler.ClangToken;
@PluginInfo( @PluginInfo(
status = PluginStatus.RELEASED, status = PluginStatus.RELEASED,
@ -97,8 +93,6 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin {
} }
} }
Msg.info(this, "Marker");
// Log to both console and log file // Log to both console and log file
Msg.info(this, "GhydraMCPPlugin loaded on port " + port); Msg.info(this, "GhydraMCPPlugin loaded on port " + port);
System.out.println("[GhydraMCP] Plugin loaded on port " + port); System.out.println("[GhydraMCP] Plugin loaded on port " + port);
@ -176,12 +170,28 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin {
Map<String, String> params = parsePostParams(exchange); Map<String, String> params = parsePostParams(exchange);
if (params.containsKey("newName")) { if (params.containsKey("newName")) {
// Rename variable // Rename variable
String result = renameVariable(functionName, variableName, params.get("newName")); boolean success = renameVariable(functionName, variableName, params.get("newName"));
sendResponse(exchange, result); JsonObject response = new JsonObject();
response.addProperty("success", success);
response.addProperty("message", success ? "Variable renamed successfully" : "Failed to rename variable");
response.addProperty("timestamp", System.currentTimeMillis());
response.addProperty("port", this.port);
if (!success) {
exchange.sendResponseHeaders(400, 0);
}
sendJsonResponse(exchange, response);
} else if (params.containsKey("dataType")) { } else if (params.containsKey("dataType")) {
// Retype variable // Retype variable
String result = retypeVariable(functionName, variableName, params.get("dataType")); boolean success = retypeVariable(functionName, variableName, params.get("dataType"));
sendResponse(exchange, result); JsonObject response = new JsonObject();
response.addProperty("success", success);
response.addProperty("message", success ? "Variable retyped successfully" : "Failed to retype variable");
response.addProperty("timestamp", System.currentTimeMillis());
response.addProperty("port", this.port);
if (!success) {
exchange.sendResponseHeaders(400, 0);
}
sendJsonResponse(exchange, response);
} else { } else {
sendResponse(exchange, "Missing required parameter: newName or dataType"); sendResponse(exchange, "Missing required parameter: newName or dataType");
} }
@ -195,62 +205,41 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin {
} else if ("PUT".equals(exchange.getRequestMethod())) { } else if ("PUT".equals(exchange.getRequestMethod())) {
Map<String, String> params = parsePostParams(exchange); Map<String, String> params = parsePostParams(exchange);
String newName = params.get("newName"); String newName = params.get("newName");
String response = renameFunction(functionName, newName) boolean success = renameFunction(functionName, newName);
? "Renamed successfully" : "Rename failed"; JsonObject response = new JsonObject();
sendResponse(exchange, response); response.addProperty("success", success);
response.addProperty("message", success ? "Renamed successfully" : "Rename failed");
response.addProperty("timestamp", System.currentTimeMillis());
response.addProperty("port", this.port);
if (!success) {
exchange.sendResponseHeaders(400, 0);
}
sendJsonResponse(exchange, response);
} else { } else {
exchange.sendResponseHeaders(405, -1); // Method Not Allowed exchange.sendResponseHeaders(405, -1); // Method Not Allowed
} }
} }
}); });
// Class resources with detailed logging // Class resources
server.createContext("/classes", exchange -> { server.createContext("/classes", exchange -> {
try {
if ("GET".equals(exchange.getRequestMethod())) { if ("GET".equals(exchange.getRequestMethod())) {
try { try {
Map<String, String> qparams = parseQueryParams(exchange); Map<String, String> qparams = parseQueryParams(exchange);
int offset = parseIntOrDefault(qparams.get("offset"), 0); int offset = parseIntOrDefault(qparams.get("offset"), 0);
int limit = parseIntOrDefault(qparams.get("limit"), 100); int limit = parseIntOrDefault(qparams.get("limit"), 100);
sendJsonResponse(exchange, getAllClassNames(offset, limit));
String result = getAllClassNames(offset, limit);
JsonObject json = new JsonObject();
json.addProperty("success", true);
json.addProperty("result", result);
json.addProperty("timestamp", System.currentTimeMillis());
json.addProperty("port", this.port);
Gson gson = new Gson();
String jsonStr = gson.toJson(json);
exchange.getResponseHeaders().set("Content-Type", "application/json; charset=utf-8");
byte[] bytes = jsonStr.getBytes(StandardCharsets.UTF_8);
exchange.sendResponseHeaders(200, bytes.length);
OutputStream os = exchange.getResponseBody();
os.write(bytes);
os.flush();
os.close();
} catch (Exception e) { } catch (Exception e) {
Msg.error(this, "/classes: Error in request processing: " + e.getMessage(), e); Msg.error(this, "/classes: Error in request processing", e);
try { try {
sendErrorResponse(exchange, 500, "Internal server error: " + e.getMessage()); sendErrorResponse(exchange, 500, "Internal server error");
} catch (IOException ioe) { } catch (IOException ioe) {
Msg.error(this, "/classes: Failed to send error response: " + ioe.getMessage(), ioe); Msg.error(this, "/classes: Failed to send error response", ioe);
} }
} }
} else { } else {
exchange.sendResponseHeaders(405, -1); // Method Not Allowed exchange.sendResponseHeaders(405, -1); // Method Not Allowed
} }
} catch (Exception e) {
Msg.error(this, "/classes: Unhandled error: " + e.getMessage(), e);
}
}); });
// Memory segments // Memory segments
@ -309,8 +298,16 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin {
sendResponse(exchange, listDefinedData(offset, limit)); sendResponse(exchange, listDefinedData(offset, limit));
} else if ("PUT".equals(exchange.getRequestMethod())) { } else if ("PUT".equals(exchange.getRequestMethod())) {
Map<String, String> params = parsePostParams(exchange); Map<String, String> params = parsePostParams(exchange);
renameDataAtAddress(params.get("address"), params.get("newName")); boolean success = renameDataAtAddress(params.get("address"), params.get("newName"));
sendResponse(exchange, "Rename data attempted"); JsonObject response = new JsonObject();
response.addProperty("success", success);
response.addProperty("message", success ? "Data renamed successfully" : "Failed to rename data");
response.addProperty("timestamp", System.currentTimeMillis());
response.addProperty("port", this.port);
if (!success) {
exchange.sendResponseHeaders(400, 0);
}
sendJsonResponse(exchange, response);
} else { } else {
exchange.sendResponseHeaders(405, -1); // Method Not Allowed exchange.sendResponseHeaders(405, -1); // Method Not Allowed
} }
@ -484,18 +481,38 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin {
private String getAllFunctionNames(int offset, int limit) { private String getAllFunctionNames(int offset, int limit) {
Program program = getCurrentProgram(); Program program = getCurrentProgram();
if (program == null) return "No program loaded"; if (program == null) return "{\"success\":false,\"error\":\"No program loaded\"}";
List<String> names = new ArrayList<>(); List<Map<String, String>> functions = new ArrayList<>();
for (Function f : program.getFunctionManager().getFunctions(true)) { for (Function f : program.getFunctionManager().getFunctions(true)) {
names.add(f.getName() + " @ " + f.getEntryPoint()); Map<String, String> func = new HashMap<>();
} func.put("name", f.getName());
return paginateList(names, offset, limit); func.put("address", f.getEntryPoint().toString());
functions.add(func);
} }
private String getAllClassNames(int offset, int limit) { // Apply pagination
int start = Math.max(0, offset);
int end = Math.min(functions.size(), offset + limit);
List<Map<String, String>> paginated = functions.subList(start, end);
Gson gson = new Gson();
JsonObject response = new JsonObject();
response.addProperty("success", true);
response.add("result", gson.toJsonTree(paginated));
response.addProperty("timestamp", System.currentTimeMillis());
response.addProperty("port", this.port);
return gson.toJson(response);
}
private JsonObject getAllClassNames(int offset, int limit) {
Program program = getCurrentProgram(); Program program = getCurrentProgram();
if (program == null) return "No program loaded"; if (program == null) {
JsonObject error = new JsonObject();
error.addProperty("success", false);
error.addProperty("error", "No program loaded");
return error;
}
Set<String> classNames = new HashSet<>(); Set<String> classNames = new HashSet<>();
for (Symbol symbol : program.getSymbolTable().getAllSymbols(true)) { for (Symbol symbol : program.getSymbolTable().getAllSymbols(true)) {
@ -504,21 +521,47 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin {
classNames.add(ns.getName()); classNames.add(ns.getName());
} }
} }
// Convert set to list for pagination
// Convert to sorted list and paginate
List<String> sorted = new ArrayList<>(classNames); List<String> sorted = new ArrayList<>(classNames);
Collections.sort(sorted); Collections.sort(sorted);
return paginateList(sorted, offset, limit); int start = Math.max(0, offset);
int end = Math.min(sorted.size(), offset + limit);
List<String> paginated = sorted.subList(start, end);
JsonObject response = new JsonObject();
response.addProperty("success", true);
response.add("result", new Gson().toJsonTree(paginated));
response.addProperty("timestamp", System.currentTimeMillis());
response.addProperty("port", this.port);
return response;
} }
private String listSegments(int offset, int limit) { private String listSegments(int offset, int limit) {
Program program = getCurrentProgram(); Program program = getCurrentProgram();
if (program == null) return "No program loaded"; if (program == null) return "{\"success\":false,\"error\":\"No program loaded\"}";
List<String> lines = new ArrayList<>(); List<Map<String, String>> segments = new ArrayList<>();
for (MemoryBlock block : program.getMemory().getBlocks()) { for (MemoryBlock block : program.getMemory().getBlocks()) {
lines.add(String.format("%s: %s - %s", block.getName(), block.getStart(), block.getEnd())); Map<String, String> seg = new HashMap<>();
seg.put("name", block.getName());
seg.put("start", block.getStart().toString());
seg.put("end", block.getEnd().toString());
segments.add(seg);
} }
return paginateList(lines, offset, limit);
// Apply pagination
int start = Math.max(0, offset);
int end = Math.min(segments.size(), offset + limit);
List<Map<String, String>> paginated = segments.subList(start, end);
Gson gson = new Gson();
JsonObject response = new JsonObject();
response.addProperty("success", true);
response.add("result", gson.toJsonTree(paginated));
response.addProperty("timestamp", System.currentTimeMillis());
response.addProperty("port", this.port);
return gson.toJson(response);
} }
private String listImports(int offset, int limit) { private String listImports(int offset, int limit) {
@ -670,10 +713,11 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin {
return successFlag.get(); return successFlag.get();
} }
private void renameDataAtAddress(String addressStr, String newName) { private boolean renameDataAtAddress(String addressStr, String newName) {
Program program = getCurrentProgram(); Program program = getCurrentProgram();
if (program == null) return; if (program == null) return false;
AtomicBoolean successFlag = new AtomicBoolean(false);
try { try {
SwingUtilities.invokeAndWait(() -> { SwingUtilities.invokeAndWait(() -> {
int tx = program.startTransaction("Rename data"); int tx = program.startTransaction("Rename data");
@ -686,8 +730,10 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin {
Symbol symbol = symTable.getPrimarySymbol(addr); Symbol symbol = symTable.getPrimarySymbol(addr);
if (symbol != null) { if (symbol != null) {
symbol.setName(newName, SourceType.USER_DEFINED); symbol.setName(newName, SourceType.USER_DEFINED);
successFlag.set(true);
} else { } else {
symTable.createLabel(addr, newName, SourceType.USER_DEFINED); symTable.createLabel(addr, newName, SourceType.USER_DEFINED);
successFlag.set(true);
} }
} }
} }
@ -695,13 +741,14 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin {
Msg.error(this, "Rename data error", e); Msg.error(this, "Rename data error", e);
} }
finally { finally {
program.endTransaction(tx, true); program.endTransaction(tx, successFlag.get());
} }
}); });
} }
catch (InterruptedException | InvocationTargetException e) { catch (InterruptedException | InvocationTargetException e) {
Msg.error(this, "Failed to execute rename data on Swing thread", e); Msg.error(this, "Failed to execute rename data on Swing thread", e);
} }
return successFlag.get();
} }
// ---------------------------------------------------------------------------------- // ----------------------------------------------------------------------------------
@ -789,9 +836,9 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin {
} }
} }
private String renameVariable(String functionName, String oldName, String newName) { private boolean renameVariable(String functionName, String oldName, String newName) {
Program program = getCurrentProgram(); Program program = getCurrentProgram();
if (program == null) return "No program loaded"; if (program == null) return false;
DecompInterface decomp = new DecompInterface(); DecompInterface decomp = new DecompInterface();
decomp.openProgram(program); decomp.openProgram(program);
@ -805,22 +852,22 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin {
} }
if (func == null) { if (func == null) {
return "Function not found"; return false;
} }
DecompileResults result = decomp.decompileFunction(func, 30, new ConsoleTaskMonitor()); DecompileResults result = decomp.decompileFunction(func, 30, new ConsoleTaskMonitor());
if (result == null || !result.decompileCompleted()) { if (result == null || !result.decompileCompleted()) {
return "Decompilation failed"; return false;
} }
HighFunction highFunction = result.getHighFunction(); HighFunction highFunction = result.getHighFunction();
if (highFunction == null) { if (highFunction == null) {
return "Decompilation failed (no high function)"; return false;
} }
LocalSymbolMap localSymbolMap = highFunction.getLocalSymbolMap(); LocalSymbolMap localSymbolMap = highFunction.getLocalSymbolMap();
if (localSymbolMap == null) { if (localSymbolMap == null) {
return "Decompilation failed (no local symbol map)"; return false;
} }
HighSymbol highSymbol = null; HighSymbol highSymbol = null;
@ -833,12 +880,12 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin {
highSymbol = symbol; highSymbol = symbol;
} }
if (symbolName.equals(newName)) { if (symbolName.equals(newName)) {
return "Error: A variable with name '" + newName + "' already exists in this function"; return false;
} }
} }
if (highSymbol == null) { if (highSymbol == null) {
return "Variable not found"; return false;
} }
boolean commitRequired = checkFullCommit(highSymbol, highFunction); boolean commitRequired = checkFullCommit(highSymbol, highFunction);
@ -871,11 +918,10 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin {
} }
}); });
} catch (InterruptedException | InvocationTargetException e) { } catch (InterruptedException | InvocationTargetException e) {
String errorMsg = "Failed to execute rename on Swing thread: " + e.getMessage(); Msg.error(this, "Failed to execute rename on Swing thread", e);
Msg.error(this, errorMsg, e); return false;
return errorMsg;
} }
return successFlag.get() ? "Variable renamed" : "Failed to rename variable"; return successFlag.get();
} }
/** /**
@ -914,15 +960,15 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin {
return false; return false;
} }
private String retypeVariable(String functionName, String varName, String dataTypeName) { private boolean retypeVariable(String functionName, String varName, String dataTypeName) {
if (varName == null || varName.isEmpty() || dataTypeName == null || dataTypeName.isEmpty()) { if (varName == null || varName.isEmpty() || dataTypeName == null || dataTypeName.isEmpty()) {
return "Both variable name and data type are required"; return false;
} }
Program program = getCurrentProgram(); Program program = getCurrentProgram();
if (program == null) return "No program loaded"; if (program == null) return false;
AtomicReference<String> result = new AtomicReference<>("Variable retype failed"); AtomicBoolean result = new AtomicBoolean(false);
try { try {
SwingUtilities.invokeAndWait(() -> { SwingUtilities.invokeAndWait(() -> {
@ -930,7 +976,6 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin {
try { try {
Function function = findFunctionByName(program, functionName); Function function = findFunctionByName(program, functionName);
if (function == null) { if (function == null) {
result.set("Function not found: " + functionName);
return; return;
} }
@ -940,13 +985,11 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin {
DecompileResults decompRes = decomp.decompileFunction(function, 30, new ConsoleTaskMonitor()); DecompileResults decompRes = decomp.decompileFunction(function, 30, new ConsoleTaskMonitor());
if (decompRes == null || !decompRes.decompileCompleted()) { if (decompRes == null || !decompRes.decompileCompleted()) {
result.set("Failed to decompile function: " + functionName);
return; return;
} }
HighFunction highFunction = decompRes.getHighFunction(); HighFunction highFunction = decompRes.getHighFunction();
if (highFunction == null) { if (highFunction == null) {
result.set("Failed to get high function");
return; return;
} }
@ -963,14 +1006,12 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin {
} }
if (targetSymbol == null) { if (targetSymbol == null) {
result.set("Variable not found: " + varName);
return; return;
} }
// Find the data type by name // Find the data type by name
DataType dataType = findDataType(program, dataTypeName); DataType dataType = findDataType(program, dataTypeName);
if (dataType == null) { if (dataType == null) {
result.set("Data type not found: " + dataTypeName);
return; return;
} }
@ -978,17 +1019,17 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin {
HighFunctionDBUtil.updateDBVariable(targetSymbol, targetSymbol.getName(), dataType, HighFunctionDBUtil.updateDBVariable(targetSymbol, targetSymbol.getName(), dataType,
SourceType.USER_DEFINED); SourceType.USER_DEFINED);
result.set("Variable '" + varName + "' retyped to '" + dataTypeName + "'"); result.set(true);
} catch (Exception e) { } catch (Exception e) {
Msg.error(this, "Error retyping variable", e); Msg.error(this, "Error retyping variable", e);
result.set("Error: " + e.getMessage()); result.set(false);
} finally { } finally {
program.endTransaction(tx, true); program.endTransaction(tx, true);
} }
}); });
} catch (InterruptedException | InvocationTargetException e) { } catch (InterruptedException | InvocationTargetException e) {
Msg.error(this, "Failed to execute on Swing thread", e); Msg.error(this, "Failed to execute on Swing thread", e);
result.set("Error: " + e.getMessage()); result.set(false);
} }
return result.get(); return result.get();
@ -1277,12 +1318,22 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin {
private void sendResponse(HttpExchange exchange, Object response) throws IOException { private void sendResponse(HttpExchange exchange, Object response) throws IOException {
if (response instanceof String && ((String)response).startsWith("{")) {
// Already JSON formatted, send as-is
byte[] bytes = ((String)response).getBytes(StandardCharsets.UTF_8);
exchange.getResponseHeaders().set("Content-Type", "application/json; charset=utf-8");
exchange.sendResponseHeaders(200, bytes.length);
try (OutputStream os = exchange.getResponseBody()) {
os.write(bytes);
}
} else {
// Wrap in standard response format
JsonObject json = new JsonObject(); JsonObject json = new JsonObject();
json.addProperty("success", true); json.addProperty("success", true);
if (response instanceof String) { if (response instanceof String) {
json.addProperty("result", (String)response); json.addProperty("result", (String)response);
} else { } else {
json.addProperty("data", response.toString()); json.add("result", new Gson().toJsonTree(response));
} }
json.addProperty("timestamp", System.currentTimeMillis()); json.addProperty("timestamp", System.currentTimeMillis());
json.addProperty("port", this.port); json.addProperty("port", this.port);
@ -1293,6 +1344,7 @@ public class GhydraMCPPlugin extends Plugin implements ApplicationLevelPlugin {
} }
sendJsonResponse(exchange, json); sendJsonResponse(exchange, json);
} }
}
private void sendJsonResponse(HttpExchange exchange, JsonObject jsonObj) throws IOException { private void sendJsonResponse(HttpExchange exchange, JsonObject jsonObj) throws IOException {
try { try {

View File

@ -84,8 +84,13 @@ class GhydraMCPHttpApiTests(unittest.TestCase):
self.assertIn("timestamp", data) self.assertIn("timestamp", data)
self.assertIn("port", data) self.assertIn("port", data)
# Check that we have either result or data # Check result is an array of function objects
self.assertTrue("result" in data or "data" in data) self.assertIn("result", data)
self.assertIsInstance(data["result"], list)
if data["result"]: # If there are functions
func = data["result"][0]
self.assertIn("name", func)
self.assertIn("address", func)
def test_functions_with_pagination(self): def test_functions_with_pagination(self):
"""Test the /functions endpoint with pagination""" """Test the /functions endpoint with pagination"""
@ -101,9 +106,18 @@ class GhydraMCPHttpApiTests(unittest.TestCase):
self.assertIn("timestamp", data) self.assertIn("timestamp", data)
self.assertIn("port", data) self.assertIn("port", data)
# Check result is an array of max 5 function objects
self.assertIn("result", data)
self.assertIsInstance(data["result"], list)
self.assertLessEqual(len(data["result"]), 5)
if data["result"]: # If there are functions
func = data["result"][0]
self.assertIn("name", func)
self.assertIn("address", func)
def test_classes_endpoint(self): def test_classes_endpoint(self):
"""Test the /classes endpoint""" """Test the /classes endpoint"""
response = requests.get(f"{BASE_URL}/classes") response = requests.get(f"{BASE_URL}/classes?offset=0&limit=10")
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
# Verify response is valid JSON # Verify response is valid JSON
@ -115,9 +129,15 @@ class GhydraMCPHttpApiTests(unittest.TestCase):
self.assertIn("timestamp", data) self.assertIn("timestamp", data)
self.assertIn("port", data) self.assertIn("port", data)
# Check result is an array of class names
self.assertIn("result", data)
self.assertIsInstance(data["result"], list)
if data["result"]: # If there are classes
self.assertIsInstance(data["result"][0], str)
def test_segments_endpoint(self): def test_segments_endpoint(self):
"""Test the /segments endpoint""" """Test the /segments endpoint"""
response = requests.get(f"{BASE_URL}/segments") response = requests.get(f"{BASE_URL}/segments?offset=0&limit=10")
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
# Verify response is valid JSON # Verify response is valid JSON
@ -129,6 +149,15 @@ class GhydraMCPHttpApiTests(unittest.TestCase):
self.assertIn("timestamp", data) self.assertIn("timestamp", data)
self.assertIn("port", data) self.assertIn("port", data)
# Check result is an array of segment objects
self.assertIn("result", data)
self.assertIsInstance(data["result"], list)
if data["result"]: # If there are segments
seg = data["result"][0]
self.assertIn("name", seg)
self.assertIn("start", seg)
self.assertIn("end", seg)
def test_variables_endpoint(self): def test_variables_endpoint(self):
"""Test the /variables endpoint""" """Test the /variables endpoint"""
response = requests.get(f"{BASE_URL}/variables") response = requests.get(f"{BASE_URL}/variables")

View File

@ -3,13 +3,12 @@
Test script for the GhydraMCP bridge using the MCP client. Test script for the GhydraMCP bridge using the MCP client.
This script tests the bridge by sending MCP requests and handling responses. This script tests the bridge by sending MCP requests and handling responses.
""" """
import asyncio import json
import logging import logging
import sys import sys
from typing import Any from typing import Any
import anyio import anyio
from mcp.client.session import ClientSession from mcp.client.session import ClientSession
from mcp.client.stdio import StdioServerParameters, stdio_client from mcp.client.stdio import StdioServerParameters, stdio_client
@ -69,31 +68,27 @@ async def test_bridge():
arguments={"port": 8192, "offset": 0, "limit": 5} arguments={"port": 8192, "offset": 0, "limit": 5}
) )
if not hasattr(list_funcs, "result") or not hasattr(list_funcs.result, "content") or not list_funcs.result.content: if not list_funcs or not list_funcs.content:
logger.warning("No functions found - skipping mutating tests") logger.warning("No functions found - skipping mutating tests")
return return
# The list_functions result contains a JSON string in the text field # The list_functions result contains the function data directly
func_json = list_funcs.result.content[0].get("text", "") if not list_funcs.content:
if not func_json:
logger.warning("No function data found - skipping mutating tests") logger.warning("No function data found - skipping mutating tests")
return return
# Parse the JSON response
try: try:
# Parse the JSON to get the function list func_data = json.loads(list_funcs.content[0].text)
func_data = json.loads(func_json) func_list = func_data.get("result", [])
func_list = func_data.get("result", "").split("\n")
if not func_list: if not func_list:
logger.warning("No functions in result - skipping mutating tests") logger.warning("No functions in result - skipping mutating tests")
return return
# Extract first function name (format: "name @ address") # Get first function's name
func_name = func_list[0].split("@")[0].strip() func_name = func_list[0].get("name", "")
except (json.JSONDecodeError, AttributeError) as e:
logger.warning(f"Error parsing function data: {e} - skipping mutating tests")
return
if not func_name: if not func_name:
logger.warning("Could not parse function name - skipping mutating tests") logger.warning("No function name found - skipping mutating tests")
return return
# Get full function details # Get full function details
@ -102,27 +97,50 @@ async def test_bridge():
arguments={"port": 8192, "name": func_name} arguments={"port": 8192, "name": func_name}
) )
if not hasattr(func_details, "result") or not hasattr(func_details.result, "content") or not func_details.result.content: if not func_details.content or not func_details.content[0].text:
logger.warning("Could not get function details - skipping mutating tests") logger.warning("Could not get function details - skipping mutating tests")
return return
func_content = func_details.result.content[0] # Parse function details - response is the decompiled code text
func_name = func_content.get("text", "").split("\n")[0] func_text = func_details.content[0].text
func_address = func_content.get("address", "") if not func_text:
logger.warning("Empty function details - skipping mutating tests")
return
# First line contains name and address
first_line = func_text.split('\n')[0]
if not first_line:
logger.warning("Invalid function format - skipping mutating tests")
return
# Extract name and address from first line
parts = first_line.split()
if len(parts) < 2:
logger.warning("Could not parse function details - skipping mutating tests")
return
func_name = parts[1] # Second part is function name
func_address = parts[0] # First part is address
if not func_name or not func_address: if not func_name or not func_address:
logger.warning("Could not get valid function name/address - skipping mutating tests") logger.warning("Could not get valid function name/address - skipping mutating tests")
return return
except json.JSONDecodeError as e:
logger.warning(f"Error parsing function data: {e} - skipping mutating tests")
return
# Test function renaming # Test function renaming
original_name = func_name original_name = func_name
test_name = f"{func_name}_test" test_name = f"{func_name}_test"
# Rename to test name # Test successful rename operations
rename_result = await session.call_tool( rename_result = await session.call_tool(
"update_function", "update_function",
arguments={"port": 8192, "name": original_name, "new_name": test_name} arguments={"port": 8192, "name": original_name, "new_name": test_name}
) )
rename_data = json.loads(rename_result.content[0].text)
assert rename_data.get("success") is True, f"Rename failed: {rename_data}"
logger.info(f"Rename result: {rename_result}") logger.info(f"Rename result: {rename_result}")
# Verify rename # Verify rename
@ -130,6 +148,8 @@ async def test_bridge():
"get_function", "get_function",
arguments={"port": 8192, "name": test_name} arguments={"port": 8192, "name": test_name}
) )
renamed_data = json.loads(renamed_func.content[0].text)
assert renamed_data.get("success") is True, f"Get renamed function failed: {renamed_data}"
logger.info(f"Renamed function result: {renamed_func}") logger.info(f"Renamed function result: {renamed_func}")
# Rename back to original # Rename back to original
@ -137,6 +157,8 @@ async def test_bridge():
"update_function", "update_function",
arguments={"port": 8192, "name": test_name, "new_name": original_name} arguments={"port": 8192, "name": test_name, "new_name": original_name}
) )
revert_data = json.loads(revert_result.content[0].text)
assert revert_data.get("success") is True, f"Revert rename failed: {revert_data}"
logger.info(f"Revert rename result: {revert_result}") logger.info(f"Revert rename result: {revert_result}")
# Verify revert # Verify revert
@ -144,9 +166,11 @@ async def test_bridge():
"get_function", "get_function",
arguments={"port": 8192, "name": original_name} arguments={"port": 8192, "name": original_name}
) )
original_data = json.loads(original_func.content[0].text)
assert original_data.get("success") is True, f"Get original function failed: {original_data}"
logger.info(f"Original function result: {original_func}") logger.info(f"Original function result: {original_func}")
# Test adding/removing comment # Test successful comment operations
test_comment = "Test comment from MCP client" test_comment = "Test comment from MCP client"
comment_result = await session.call_tool( comment_result = await session.call_tool(
"set_decompiler_comment", "set_decompiler_comment",
@ -156,6 +180,8 @@ async def test_bridge():
"comment": test_comment "comment": test_comment
} }
) )
comment_data = json.loads(comment_result.content[0].text)
assert comment_data.get("success") is True, f"Add comment failed: {comment_data}"
logger.info(f"Add comment result: {comment_result}") logger.info(f"Add comment result: {comment_result}")
# Remove comment # Remove comment
@ -167,8 +193,39 @@ async def test_bridge():
"comment": "" "comment": ""
} }
) )
remove_data = json.loads(remove_comment_result.content[0].text)
assert remove_data.get("success") is True, f"Remove comment failed: {remove_data}"
logger.info(f"Remove comment result: {remove_comment_result}") logger.info(f"Remove comment result: {remove_comment_result}")
# Test expected failure cases
# Try to rename non-existent function
bad_rename_result = await session.call_tool(
"update_function",
arguments={"port": 8192, "name": "nonexistent_function", "new_name": "should_fail"}
)
bad_rename_data = json.loads(bad_rename_result.content[0].text)
assert bad_rename_data.get("success") is False, "Renaming non-existent function should fail"
# Try to get non-existent function
bad_get_result = await session.call_tool(
"get_function",
arguments={"port": 8192, "name": "nonexistent_function"}
)
bad_get_data = json.loads(bad_get_result.content[0].text)
assert bad_get_data.get("success") is False, "Getting non-existent function should fail"
# Try to comment on invalid address
bad_comment_result = await session.call_tool(
"set_decompiler_comment",
arguments={
"port": 8192,
"address": "0xinvalid",
"comment": "should fail"
}
)
bad_comment_data = json.loads(bad_comment_result.content[0].text)
assert bad_comment_data.get("success") is False, "Commenting on invalid address should fail"
except Exception as e: except Exception as e:
logger.error(f"Error testing mutating operations: {e}") logger.error(f"Error testing mutating operations: {e}")
raise raise