#!/usr/bin/env python3
"""
Test script for the GhydraMCP HTTP API.

This script tests the HTTP endpoints of the Java plugin. Every test is skipped
when the server at ``BASE_URL`` does not answer ``GET /info`` with HTTP 200.

Set the ``GHYDRAMCP_TEST_HOST`` environment variable to target a host other
than localhost; the port is always ``DEFAULT_PORT``.
"""
import json
import os
import time
import unittest
from urllib.parse import quote

import requests

# Default Ghidra server port
DEFAULT_PORT = 8192

# Per-request timeout (seconds) for test calls. requests has no default
# timeout, so without one a server that stops responding would hang the whole
# test run. Kept generous because some endpoints (e.g. decompilation) may be
# slow -- adjust if needed.
REQUEST_TIMEOUT = 30

# Get host from environment variable or default to localhost.
GHYDRAMCP_TEST_HOST = os.getenv('GHYDRAMCP_TEST_HOST')
# Use the *stripped* value in the URL (not just for the emptiness check), so a
# padded value such as " myhost " still produces a valid URL.
_TEST_HOST = (GHYDRAMCP_TEST_HOST or "").strip() or "localhost"
BASE_URL = f"http://{_TEST_HOST}:{DEFAULT_PORT}"


class GhydraMCPHttpApiTests(unittest.TestCase):
    """Test cases for the GhydraMCP HTTP API"""

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _get(self, path="", **kwargs):
        """Send ``GET BASE_URL + path`` and return the ``requests.Response``.

        ``timeout`` defaults to ``REQUEST_TIMEOUT``; any other keyword
        arguments (e.g. ``params``) are forwarded to ``requests.get``.
        """
        kwargs.setdefault("timeout", REQUEST_TIMEOUT)
        return requests.get(f"{BASE_URL}{path}", **kwargs)

    def _get_json(self, path="", **kwargs):
        """GET ``path``, assert HTTP 200, and return the decoded JSON body.

        ``response.json()`` raises if the body is not valid JSON, which
        errors the calling test.
        """
        response = self._get(path, **kwargs)
        self.assertEqual(response.status_code, 200)
        return response.json()

    def _assert_instance_info(self, data):
        """Assert the fields that both ``/`` and ``/info`` must return."""
        for field in ("port", "isBaseInstance", "project", "file"):
            self.assertIn(field, data)

    def _assert_function_summary(self, func):
        """Assert that a function entry carries ``name`` and ``address``."""
        self.assertIn("name", func)
        self.assertIn("address", func)

    def _first_function(self, skip_reason):
        """Return the first entry listed by ``/functions``.

        Args:
            skip_reason: Message used to skip the calling test when the
                program has no functions.

        Returns:
            The first element of the ``result`` list (a dict whose
            ``name``/``address`` keys the callers read).
        """
        data = self._get_json("/functions", params={"offset": 0, "limit": 1})
        # Check success first
        self.assertTrue(data.get("success", False), "API call failed")
        self.assertIn("result", data)
        result_list = data["result"]
        self.assertIsInstance(result_list, list)
        if not result_list:
            self.skipTest(skip_reason)
        return result_list[0]

    def assertStandardSuccessResponse(self, data, expected_result_type=None):
        """Helper to assert the standard success response structure.

        Checks that ``data`` has a truthy ``success``, a numeric
        ``timestamp``, a ``port`` equal to ``DEFAULT_PORT`` and a ``result``
        field. If ``expected_result_type`` is given, ``result`` must be an
        instance of it.
        """
        self.assertIn("success", data, "Response missing 'success' field")
        self.assertTrue(data["success"], f"API call failed: {data.get('error', 'Unknown error')}")
        self.assertIn("timestamp", data, "Response missing 'timestamp' field")
        self.assertIsInstance(data["timestamp"], (int, float), "'timestamp' should be a number")
        self.assertIn("port", data, "Response missing 'port' field")
        self.assertEqual(data["port"], DEFAULT_PORT, f"Response port mismatch: expected {DEFAULT_PORT}, got {data['port']}")
        self.assertIn("result", data, "Response missing 'result' field")
        if expected_result_type:
            self.assertIsInstance(data["result"], expected_result_type, f"'result' field type mismatch: expected {expected_result_type}, got {type(data['result'])}")

    # ------------------------------------------------------------------
    # Fixtures
    # ------------------------------------------------------------------

    def setUp(self):
        """Skip the test unless ``GET /info`` answers with HTTP 200."""
        # Check if the server is running
        try:
            response = self._get("/info", timeout=2)
            if response.status_code != 200:
                self.skipTest("Ghidra server not running or not responding")
        except requests.exceptions.RequestException:
            self.skipTest("Ghidra server not running or not accessible")

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_info_endpoint(self):
        """Test the /info endpoint"""
        data = self._get_json("/info")
        self._assert_instance_info(data)

    def test_root_endpoint(self):
        """Test the / endpoint"""
        data = self._get_json()
        self._assert_instance_info(data)

    def test_instances_endpoint(self):
        """Test the /instances endpoint"""
        data = self._get_json("/instances")
        self.assertStandardSuccessResponse(data, expected_result_type=list)

    def test_functions_endpoint(self):
        """Test the /functions endpoint"""
        data = self._get_json("/functions")
        self.assertStandardSuccessResponse(data, expected_result_type=list)
        # Check function structure only when the program has functions.
        result = data["result"]
        if result:
            self._assert_function_summary(result[0])

    def test_functions_with_pagination(self):
        """Test the /functions endpoint with pagination"""
        data = self._get_json("/functions", params={"offset": 0, "limit": 5})
        self.assertStandardSuccessResponse(data, expected_result_type=list)
        result = data["result"]
        # The server must honour the requested page size.
        self.assertLessEqual(len(result), 5)
        if result:
            self._assert_function_summary(result[0])

    def test_classes_endpoint(self):
        """Test the /classes endpoint"""
        data = self._get_json("/classes", params={"offset": 0, "limit": 10})
        self.assertStandardSuccessResponse(data, expected_result_type=list)
        # Classes are reported as plain name strings.
        result = data["result"]
        if result:
            self.assertIsInstance(result[0], str)

    def test_segments_endpoint(self):
        """Test the /segments endpoint"""
        data = self._get_json("/segments", params={"offset": 0, "limit": 10})
        self.assertStandardSuccessResponse(data, expected_result_type=list)
        result = data["result"]
        if result:
            seg = result[0]
            self.assertIn("name", seg)
            self.assertIn("start", seg)
            self.assertIn("end", seg)

    def test_variables_endpoint(self):
        """Test the /variables endpoint"""
        data = self._get_json("/variables")
        self.assertStandardSuccessResponse(data, expected_result_type=list)

    def test_get_function_by_address_endpoint(self):
        """Test the /get_function_by_address endpoint"""
        func = self._first_function("No functions available to test get_function_by_address")
        # Pass the address via ``params`` so requests URL-encodes it.
        data = self._get_json("/get_function_by_address", params={"address": func["address"]})
        self.assertStandardSuccessResponse(data, expected_result_type=dict)
        result = data["result"]
        self.assertIn("name", result)
        self.assertIn("address", result)
        self.assertIn("signature", result)
        self.assertIn("decompilation", result)
        self.assertIsInstance(result["decompilation"], str)

    def test_decompile_function_by_address_endpoint(self):
        """Test the /decompile_function endpoint"""
        func = self._first_function("No functions available to test decompile_function")
        data = self._get_json("/decompile_function", params={"address": func["address"]})
        self.assertStandardSuccessResponse(data, expected_result_type=dict)
        result = data["result"]
        self.assertIn("decompilation", result)
        self.assertIsInstance(result["decompilation"], str)

    def test_function_variables_endpoint(self):
        """Test the /functions/{name}/variables endpoint"""
        func = self._first_function("No functions available to test function variables")
        # Function names are not guaranteed to be URL-safe and may contain
        # characters such as '?', '#', '%' or '/'. Encode the name as a single
        # path segment. NOTE(review): assumes the server percent-decodes the
        # request path -- confirm against the plugin's routing code.
        name_segment = quote(func["name"], safe="")
        data = self._get_json(f"/functions/{name_segment}/variables")
        self.assertStandardSuccessResponse(data, expected_result_type=dict)
        result = data["result"]
        self.assertIn("function", result)
        self.assertIn("variables", result)
        self.assertIsInstance(result["variables"], list)

    def test_error_handling(self):
        """Test error handling for non-existent endpoints"""
        response = self._get("/nonexistent_endpoint")
        # This should return 404, but some servers might return other codes
        self.assertNotEqual(response.status_code, 200)

    def test_get_current_address(self):
        """Test the /get_current_address endpoint"""
        data = self._get_json("/get_current_address")
        self.assertStandardSuccessResponse(data, expected_result_type=dict)
        result = data["result"]
        self.assertIn("address", result)
        self.assertIsInstance(result["address"], str)

    def test_get_current_function(self):
        """Test the /get_current_function endpoint"""
        data = self._get_json("/get_current_function")
        self.assertStandardSuccessResponse(data, expected_result_type=dict)
        result = data["result"]
        self.assertIn("name", result)
        self.assertIn("address", result)
        self.assertIn("signature", result)


if __name__ == "__main__":
    unittest.main()