Some checks failed
Build Ghidra Plugin / build (push) Has been cancelled
Refactored Java plugin to use helper methods for consistent JSON success/error responses. Fixed NullPointerException in listVariables. Updated Python tests (HTTP and MCP) to use helper assertions validating the standard response structure.
236 lines
14 KiB
Python
236 lines
14 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Test script for the GhydraMCP bridge using the MCP client.
|
|
This script tests the bridge by sending MCP requests and handling responses.
|
|
"""
|
|
import json
|
|
import logging
|
|
import sys
|
|
from typing import Any
|
|
|
|
import anyio
|
|
from mcp.client.session import ClientSession
|
|
from mcp.client.stdio import StdioServerParameters, stdio_client
|
|
|
|
# Set up logging
|
|
logging.basicConfig(level=logging.INFO)
|
|
logger = logging.getLogger("mcp_client_test")
|
|
|
|
async def assert_standard_mcp_success_response(response_content, expected_result_type=None):
|
|
"""Helper to assert the standard success response structure for MCP tool calls."""
|
|
assert response_content, "Response content is empty"
|
|
try:
|
|
data = json.loads(response_content[0].text)
|
|
except (json.JSONDecodeError, IndexError) as e:
|
|
assert False, f"Failed to parse JSON response: {e} - Content: {response_content}"
|
|
|
|
assert "success" in data, "Response missing 'success' field"
|
|
assert data["success"] is True, f"API call failed: {data.get('error', 'Unknown error')}"
|
|
assert "timestamp" in data, "Response missing 'timestamp' field"
|
|
assert isinstance(data["timestamp"], (int, float)), "'timestamp' should be a number"
|
|
assert "port" in data, "Response missing 'port' field"
|
|
# We don't strictly check port number here as it might vary in MCP tests
|
|
assert "result" in data, "Response missing 'result' field"
|
|
if expected_result_type:
|
|
assert isinstance(data["result"], expected_result_type), \
|
|
f"'result' field type mismatch: expected {expected_result_type}, got {type(data['result'])}"
|
|
return data # Return parsed data for further checks if needed
|
|
|
|
async def test_bridge():
|
|
"""Test the bridge using the MCP client"""
|
|
# Configure the server parameters
|
|
server_parameters = StdioServerParameters(
|
|
command=sys.executable,
|
|
args=["bridge_mcp_hydra.py"],
|
|
)
|
|
|
|
# Connect to the bridge
|
|
logger.info("Connecting to bridge...")
|
|
async with stdio_client(server_parameters) as (read_stream, write_stream):
|
|
# Create a session
|
|
logger.info("Creating session...")
|
|
async with ClientSession(read_stream, write_stream) as session:
|
|
# Initialize the session
|
|
logger.info("Initializing session...")
|
|
init_result = await session.initialize()
|
|
logger.info(f"Initialization result: {init_result}")
|
|
|
|
# List tools
|
|
logger.info("Listing tools...")
|
|
tools_result = await session.list_tools()
|
|
logger.info(f"Tools result: {tools_result}")
|
|
|
|
# Call the list_instances tool
|
|
logger.info("Calling list_instances tool...")
|
|
list_instances_result = await session.call_tool("list_instances")
|
|
logger.info(f"List instances result: {list_instances_result}")
|
|
|
|
# Call the discover_instances tool
|
|
logger.info("Calling discover_instances tool...")
|
|
discover_instances_result = await session.call_tool("discover_instances")
|
|
logger.info(f"Discover instances result: {discover_instances_result}")
|
|
|
|
# Call the list_functions tool
|
|
logger.info("Calling list_functions tool...")
|
|
list_functions_result = await session.call_tool(
|
|
"list_functions",
|
|
arguments={"port": 8192, "offset": 0, "limit": 5}
|
|
)
|
|
logger.info(f"List functions result: {list_functions_result}")
|
|
|
|
# Test mutating operations by changing and reverting
|
|
logger.info("Testing mutating operations...")
|
|
|
|
try:
|
|
# Get a known function to test with from list_functions result
|
|
list_funcs = await session.call_tool(
|
|
"list_functions",
|
|
arguments={"port": 8192, "offset": 0, "limit": 5}
|
|
)
|
|
|
|
if not list_funcs or not list_funcs.content:
|
|
logger.warning("No functions found - skipping mutating tests")
|
|
return
|
|
|
|
# Parse the JSON response from list_functions using helper
|
|
try:
|
|
list_funcs_data = await assert_standard_mcp_success_response(list_funcs.content, expected_result_type=list)
|
|
func_list = list_funcs_data.get("result", [])
|
|
if not func_list:
|
|
logger.warning("No functions in list_functions result - skipping mutating tests")
|
|
return
|
|
|
|
# Get first function's name and address
|
|
first_func = func_list[0]
|
|
func_name = first_func.get("name", "")
|
|
func_address = first_func.get("address", "")
|
|
|
|
if not func_name or not func_address:
|
|
logger.warning("No function name/address found in list_functions result - skipping mutating tests")
|
|
return
|
|
except AssertionError as e:
|
|
logger.warning(f"Error processing list_functions data: {e} - skipping mutating tests")
|
|
return
|
|
|
|
# Test function renaming
|
|
original_name = func_name
|
|
test_name = f"{func_name}_test"
|
|
|
|
# Test successful rename operations (These return simple success/message, not full result)
|
|
rename_args = {"port": 8192, "name": original_name, "new_name": test_name}
|
|
logger.info(f"Calling update_function with args: {rename_args}")
|
|
rename_result = await session.call_tool("update_function", arguments=rename_args)
|
|
rename_data = json.loads(rename_result.content[0].text) # Parse simple response
|
|
assert rename_data.get("success") is True, f"Rename failed: {rename_data}"
|
|
logger.info(f"Rename result: {rename_result}")
|
|
|
|
# Verify rename by getting the function
|
|
renamed_func = await session.call_tool("get_function", arguments={"port": 8192, "name": test_name})
|
|
renamed_data = await assert_standard_mcp_success_response(renamed_func.content, expected_result_type=dict)
|
|
assert renamed_data.get("result", {}).get("name") == test_name, f"Renamed function has wrong name: {renamed_data}"
|
|
logger.info(f"Renamed function result: {renamed_func}")
|
|
|
|
# Rename back to original
|
|
revert_args = {"port": 8192, "name": test_name, "new_name": original_name}
|
|
logger.info(f"Calling update_function with args: {revert_args}")
|
|
revert_result = await session.call_tool("update_function", arguments=revert_args)
|
|
revert_data = json.loads(revert_result.content[0].text) # Parse simple response
|
|
assert revert_data.get("success") is True, f"Revert rename failed: {revert_data}"
|
|
logger.info(f"Revert rename result: {revert_result}")
|
|
|
|
# Verify revert by getting the function
|
|
original_func = await session.call_tool("get_function", arguments={"port": 8192, "name": original_name})
|
|
original_data = await assert_standard_mcp_success_response(original_func.content, expected_result_type=dict)
|
|
assert original_data.get("result", {}).get("name") == original_name, f"Original function has wrong name: {original_data}"
|
|
logger.info(f"Original function result: {original_func}")
|
|
|
|
# Test get_function_by_address
|
|
logger.info(f"Calling get_function_by_address with address: {func_address}")
|
|
get_by_addr_result = await session.call_tool("get_function_by_address", arguments={"port": 8192, "address": func_address})
|
|
get_by_addr_data = await assert_standard_mcp_success_response(get_by_addr_result.content, expected_result_type=dict)
|
|
result_data = get_by_addr_data.get("result", {})
|
|
assert "name" in result_data, "Missing name field in get_function_by_address result"
|
|
assert "address" in result_data, "Missing address field in get_function_by_address result"
|
|
assert "signature" in result_data, "Missing signature field in get_function_by_address result"
|
|
assert "decompilation" in result_data, "Missing decompilation field in get_function_by_address result"
|
|
assert result_data.get("name") == original_name, f"Wrong name in get_function_by_address: {result_data.get('name')}"
|
|
logger.info(f"Get function by address result: {get_by_addr_result}")
|
|
|
|
# Test decompile_function_by_address
|
|
logger.info(f"Calling decompile_function_by_address with address: {func_address}")
|
|
decompile_result = await session.call_tool("decompile_function_by_address", arguments={"port": 8192, "address": func_address})
|
|
decompile_data = await assert_standard_mcp_success_response(decompile_result.content, expected_result_type=dict)
|
|
assert "decompilation" in decompile_data.get("result", {}), f"Decompile result missing 'decompilation': {decompile_data}"
|
|
assert isinstance(decompile_data.get("result", {}).get("decompilation", ""), str), f"Decompilation is not a string: {decompile_data}"
|
|
assert len(decompile_data.get("result", {}).get("decompilation", "")) > 0, f"Decompilation result is empty: {decompile_data}"
|
|
logger.info(f"Decompile function by address result: {decompile_result}")
|
|
|
|
# Test list_variables
|
|
logger.info("Calling list_variables tool...")
|
|
list_vars_result = await session.call_tool("list_variables", arguments={"port": 8192, "limit": 10})
|
|
list_vars_data = await assert_standard_mcp_success_response(list_vars_result.content, expected_result_type=list)
|
|
variables_list = list_vars_data.get("result", [])
|
|
if variables_list: # Only validate structure if we get results
|
|
for var in variables_list:
|
|
assert "name" in var, f"Variable missing name: {var}"
|
|
assert "type" in var, f"Variable missing type: {var}"
|
|
assert "dataType" in var, f"Variable missing dataType: {var}"
|
|
logger.info(f"List variables result: {list_vars_result}")
|
|
|
|
# Test successful comment operations (These return simple success/message)
|
|
test_comment = "Test comment from MCP client"
|
|
comment_args = {"port": 8192, "address": func_address, "comment": test_comment}
|
|
logger.info(f"Calling set_decompiler_comment with args: {comment_args}")
|
|
comment_result = await session.call_tool("set_decompiler_comment", arguments=comment_args)
|
|
comment_data = json.loads(comment_result.content[0].text)
|
|
assert comment_data.get("success") is True, f"Add comment failed: {comment_data}"
|
|
logger.info(f"Add comment result: {comment_result}")
|
|
|
|
# Remove comment
|
|
remove_comment_args = {"port": 8192, "address": func_address, "comment": ""}
|
|
logger.info(f"Calling set_decompiler_comment with args: {remove_comment_args}")
|
|
remove_comment_result = await session.call_tool("set_decompiler_comment", arguments=remove_comment_args)
|
|
remove_data = json.loads(remove_comment_result.content[0].text)
|
|
assert remove_data.get("success") is True, f"Remove comment failed: {remove_data}"
|
|
logger.info(f"Remove comment result: {remove_comment_result}")
|
|
|
|
# Test expected failure cases
|
|
# Try to rename non-existent function
|
|
bad_rename_args = {"port": 8192, "name": "nonexistent_function", "new_name": "should_fail"}
|
|
logger.info(f"Calling update_function with args: {bad_rename_args}")
|
|
bad_rename_result = await session.call_tool("update_function", arguments=bad_rename_args)
|
|
logger.info(f"Bad rename result: {bad_rename_result}") # Log the response
|
|
bad_rename_data = json.loads(bad_rename_result.content[0].text)
|
|
assert bad_rename_data.get("success") is False, f"Renaming non-existent function should fail, but got: {bad_rename_data}"
|
|
|
|
# Try to get non-existent function
|
|
bad_get_result = await session.call_tool(
|
|
"get_function",
|
|
arguments={"port": 8192, "name": "nonexistent_function"}
|
|
)
|
|
logger.info(f"Bad get result: {bad_get_result}") # Log the response
|
|
bad_get_data = json.loads(bad_get_result.content[0].text)
|
|
assert bad_get_data.get("success") is False, f"Getting non-existent function should fail, but got: {bad_get_data}"
|
|
|
|
# Try to comment on invalid address
|
|
bad_comment_args = {"port": 8192, "address": "0xinvalid", "comment": "should fail"}
|
|
logger.info(f"Calling set_decompiler_comment with args: {bad_comment_args}")
|
|
bad_comment_result = await session.call_tool("set_decompiler_comment", arguments=bad_comment_args)
|
|
bad_comment_data = json.loads(bad_comment_result.content[0].text)
|
|
assert bad_comment_data.get("success") is False, "Commenting on invalid address should fail"
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error testing mutating operations: {e}")
|
|
raise
|
|
|
|
def main():
|
|
"""Main entry point"""
|
|
try:
|
|
anyio.run(test_bridge)
|
|
except Exception as e:
|
|
logger.error(f"Error: {e}")
|
|
sys.exit(1)
|
|
|
|
if __name__ == "__main__":
|
|
main()
|