mcghidra/test_http_api.py
2025-04-10 14:42:53 +02:00

311 lines
12 KiB
Python

#!/usr/bin/env python3
"""
Test script for the GhydraMCP HTTP API.
This script tests the HTTP endpoints of the Java plugin.
"""
import json
import os
import time
import unittest
from urllib.parse import quote

import requests
# Default port the GhydraMCP HTTP server is expected to listen on.
DEFAULT_PORT = 8192

# Raw value of the GHYDRAMCP_TEST_HOST environment variable (None if unset).
GHYDRAMCP_TEST_HOST = os.getenv('GHYDRAMCP_TEST_HOST')
# Surrounding whitespace is stripped so a padded value still yields a valid
# URL; an unset, empty, or whitespace-only value falls back to localhost.
_test_host = (GHYDRAMCP_TEST_HOST or "").strip() or "localhost"

# Optional port override (e.g. to target an instance on a non-default port).
# It must be an integer if set; an unset or blank value uses DEFAULT_PORT.
_test_port = (os.getenv('GHYDRAMCP_TEST_PORT') or "").strip()
GHYDRAMCP_TEST_PORT = int(_test_port) if _test_port else DEFAULT_PORT

BASE_URL = f"http://{_test_host}:{GHYDRAMCP_TEST_PORT}"
class GhydraMCPHttpApiTests(unittest.TestCase):
    """Test cases for the GhydraMCP HTTP API.

    Each test sends real HTTP requests to the server at ``BASE_URL``;
    ``setUp`` skips the test when that server does not answer ``/info``.
    """

    # Per-request timeout in seconds for the test requests. Without one, a
    # server that accepts the connection but never replies hangs the run.
    REQUEST_TIMEOUT = 10

    def assertStandardSuccessResponse(self, data, expected_result_type=None):
        """Assert that ``data`` carries the standard success envelope.

        Checks that ``success`` is present and truthy and that ``id``,
        ``instance`` and ``result`` are present. If ``expected_result_type``
        is given, also checks that ``result`` is an instance of that type.
        """
        self.assertIn("success", data, "Response missing 'success' field")
        self.assertTrue(data["success"], f"API call failed: {data.get('error', 'Unknown error')}")
        self.assertIn("id", data, "Response missing 'id' field")
        self.assertIn("instance", data, "Response missing 'instance' field")
        self.assertIn("result", data, "Response missing 'result' field")
        if expected_result_type is not None:
            self.assertIsInstance(
                data["result"],
                expected_result_type,
                f"'result' field type mismatch: expected {expected_result_type}, got {type(data['result'])}",
            )

    def setUp(self):
        """Skip the test unless the server answers /info with HTTP 200."""
        # Short timeout: this is only a liveness probe run before every test.
        try:
            response = requests.get(f"{BASE_URL}/info", timeout=2)
        except requests.exceptions.RequestException:
            self.skipTest("Ghidra server not running or not accessible")
        if response.status_code != 200:
            self.skipTest("Ghidra server not running or not responding")

    def _get(self, path="", params=None):
        """Send a GET request to ``BASE_URL + path`` with REQUEST_TIMEOUT.

        ``params`` is handed to requests, which URL-encodes it into the
        query string.
        """
        return requests.get(f"{BASE_URL}{path}", params=params, timeout=self.REQUEST_TIMEOUT)

    def _get_json(self, path="", params=None):
        """GET ``path``, assert an HTTP 200 status, and return the decoded JSON."""
        response = self._get(path, params)
        self.assertEqual(response.status_code, 200)
        return response.json()

    def _get_first_function(self, purpose):
        """Return the first entry listed by /functions, or skip the test.

        ``purpose`` completes the skip message
        ("No functions available to test <purpose>").
        """
        data = self._get_json("/functions", {"offset": 0, "limit": 1})
        self.assertStandardSuccessResponse(data, expected_result_type=list)
        functions = data["result"]
        if not functions:
            self.skipTest(f"No functions available to test {purpose}")
        return functions[0]

    def test_info_endpoint(self):
        """Test the /info endpoint"""
        data = self._get_json("/info")
        self.assertStandardSuccessResponse(data, expected_result_type=dict)
        result = data["result"]
        self.assertIn("isBaseInstance", result)
        self.assertIn("project", result)
        self.assertIn("file", result)

    def test_root_endpoint(self):
        """Test the / endpoint"""
        data = self._get_json()
        self.assertStandardSuccessResponse(data, expected_result_type=dict)
        result = data["result"]
        self.assertIn("isBaseInstance", result)
        self.assertIn("message", result)

    def test_instances_endpoint(self):
        """Test the /instances endpoint"""
        data = self._get_json("/instances")
        self.assertStandardSuccessResponse(data, expected_result_type=list)

    def test_functions_endpoint(self):
        """Test the /functions endpoint"""
        data = self._get_json("/functions")
        self.assertStandardSuccessResponse(data, expected_result_type=list)
        # Check the shape of a function entry if any exist.
        result = data["result"]
        if result:
            func = result[0]
            self.assertIn("name", func)
            self.assertIn("address", func)

    def test_functions_with_pagination(self):
        """Test the /functions endpoint with pagination"""
        data = self._get_json("/functions", {"offset": 0, "limit": 5})
        self.assertStandardSuccessResponse(data, expected_result_type=list)
        # The server must honour the limit; check entry shape if any exist.
        result = data["result"]
        self.assertLessEqual(len(result), 5)
        if result:
            func = result[0]
            self.assertIn("name", func)
            self.assertIn("address", func)

    def test_classes_endpoint(self):
        """Test the /classes endpoint"""
        data = self._get_json("/classes", {"offset": 0, "limit": 10})
        self.assertStandardSuccessResponse(data, expected_result_type=list)
        # Class entries are plain name strings.
        result = data["result"]
        if result:
            self.assertIsInstance(result[0], str)

    def test_segments_endpoint(self):
        """Test the /segments endpoint"""
        data = self._get_json("/segments", {"offset": 0, "limit": 10})
        self.assertStandardSuccessResponse(data, expected_result_type=list)
        # Check the shape of a segment entry if any exist.
        result = data["result"]
        if result:
            seg = result[0]
            self.assertIn("name", seg)
            self.assertIn("start", seg)
            self.assertIn("end", seg)

    def test_variables_endpoint(self):
        """Test the /variables endpoint"""
        data = self._get_json("/variables")
        self.assertStandardSuccessResponse(data, expected_result_type=list)

    def test_get_function_by_address_endpoint(self):
        """Test the /get_function_by_address endpoint"""
        func_address = self._get_first_function("get_function_by_address")["address"]
        # Pass the address via params so it is URL-encoded.
        data = self._get_json("/get_function_by_address", {"address": func_address})
        self.assertStandardSuccessResponse(data, expected_result_type=dict)
        result = data["result"]
        self.assertIn("name", result)
        self.assertIn("address", result)
        self.assertIn("signature", result)
        self.assertIn("decompilation", result)
        self.assertIsInstance(result["decompilation"], str)

    def test_decompile_function_by_address_endpoint(self):
        """Test the /decompile_function endpoint"""
        func_address = self._get_first_function("decompile_function")["address"]
        data = self._get_json("/decompile_function", {"address": func_address})
        self.assertStandardSuccessResponse(data, expected_result_type=dict)
        result = data["result"]
        self.assertIn("decompilation", result)
        self.assertIsInstance(result["decompilation"], str)

    def test_function_variables_endpoint(self):
        """Test the /functions/{name}/variables endpoint"""
        func_name = self._get_first_function("function variables")["name"]
        # Function names can contain characters that are special in URLs
        # (e.g. a leading '?' in MSVC-mangled names, '/', spaces), so encode
        # the name as a single path segment.
        data = self._get_json(f"/functions/{quote(func_name, safe='')}/variables")
        self.assertStandardSuccessResponse(data, expected_result_type=dict)
        result = data["result"]
        self.assertIn("function", result)
        self.assertIn("variables", result)
        self.assertIsInstance(result["variables"], list)

    def test_error_handling(self):
        """Test error handling for non-existent endpoints"""
        response = self._get("/nonexistent_endpoint")
        # Ideally a 404, but any non-200 status is accepted.
        self.assertNotEqual(response.status_code, 200)

    def test_get_current_address(self):
        """Test the /get_current_address endpoint"""
        data = self._get_json("/get_current_address")
        self.assertStandardSuccessResponse(data, expected_result_type=dict)
        result = data["result"]
        self.assertIn("address", result)
        self.assertIsInstance(result["address"], str)

    def test_get_current_function(self):
        """Test the /get_current_function endpoint"""
        data = self._get_json("/get_current_function")
        self.assertStandardSuccessResponse(data, expected_result_type=dict)
        result = data["result"]
        self.assertIn("name", result)
        self.assertIn("address", result)
        self.assertIn("signature", result)
if __name__ == "__main__":
    # Script entry point: hand control to unittest's command-line runner,
    # which collects and runs the TestCase classes defined in this module.
    unittest.main(module="__main__")