`return_all=True` on large binaries (1800+ functions) produced 72K-character responses that exceeded the MCP tool result limit. Instead of truncating, oversized responses now return a structured summary with sample data, available fields, and actionable instructions for narrowing the query.

Three layers of filtering:
- Server-side grep: Jython HTTP handlers filter during Ghidra iteration
- Field projection: jq-style key selection strips unneeded fields
- Token budget guard: responses exceeding 8k tokens return a summary

New files: core/filtering.py (project_fields, apply_grep, estimate_and_guard)
Modified: config, pagination, base mixin, all 5 domain mixins, headless server
273 lines
8.5 KiB
Python
273 lines
8.5 KiB
Python
"""Base mixin class for GhydraMCP domain mixins.
|
|
|
|
Provides shared state and utilities for all domain mixins.
|
|
"""
|
|
|
|
import time
|
|
from threading import Lock
|
|
from typing import Any, Dict, Optional
|
|
|
|
from fastmcp import Context
|
|
from fastmcp.contrib.mcp_mixin import MCPMixin
|
|
|
|
from ..config import get_config
|
|
from ..core.http_client import safe_get, safe_post, safe_put, safe_patch, safe_delete, simplify_response
|
|
from ..core.pagination import paginate_response
|
|
from ..core.logging import log_info, log_debug, log_warning, log_error
|
|
|
|
|
|
class GhydraMixinBase(MCPMixin):
    """Base class for GhydraMCP domain mixins.

    Provides shared instance state and common utilities.
    All domain mixins should inherit from this class.

    The instance registry and the "current port" selection are stored as
    class attributes on ``GhydraMixinBase`` itself, so every subclass reads
    and writes the same values.
    """

    # Shared state across all mixins.
    # Registry of known instances: port -> info dict with the keys
    # "url", "project", "file" and "registered_at" (see register_instance).
    _instances: Dict[int, Dict[str, Any]] = {}
    # Guards every read and write of _instances.
    _instances_lock = Lock()
    # Port of the instance used when a tool call does not pass one explicitly.
    _current_port: Optional[int] = None

    def __init__(self):
        """Initialize the mixin with shared state."""
        # Intentionally a no-op: all state lives on the class, not the instance.
        pass

    @classmethod
    def get_current_port(cls) -> Optional[int]:
        """Get the current working instance port.

        Returns:
            The selected port, or None when no instance has been selected.
        """
        # Read from the base class so the value is the same for every mixin.
        return GhydraMixinBase._current_port

    @classmethod
    def set_current_port(cls, port: int) -> None:
        """Set the current working instance port.

        Args:
            port: Port number to use as the default for later calls
        """
        # Assigning through ``cls`` would create a shadowing attribute on
        # whichever subclass made the call, leaving sibling mixins with the
        # old value. Writing to the base keeps the selection truly shared.
        GhydraMixinBase._current_port = port

    @classmethod
    def get_instance_port(cls, port: Optional[int] = None) -> int:
        """Get instance port, using current if not specified.

        Resolution order: the explicit ``port`` argument, then the current
        working port, then the first port of the configured quick discovery
        range if an instance is registered on it.

        Args:
            port: Explicit port (optional)

        Returns:
            Port number to use

        Raises:
            ValueError: If no port specified and no current instance set
        """
        if port is not None:
            return port

        current = GhydraMixinBase._current_port
        if current is not None:
            return current

        # Try default port
        default_port = get_config().quick_discovery_range.start
        # Take the lock for the registry read, like every other accessor.
        with cls._instances_lock:
            default_registered = default_port in cls._instances
        if default_registered:
            return default_port

        raise ValueError(
            "No Ghidra instance specified. Use instances_use(port) to set a working instance, "
            "or pass port= parameter explicitly."
        )

    @classmethod
    def register_instance(cls, port: int, url: Optional[str] = None) -> str:
        """Register a Ghidra instance.

        The instance is probed with a GET on the root endpoint and is only
        added to the registry when it reports success and an API version of
        at least ``config.expected_api_version``. Failures are reported through
        the returned message rather than raised.

        Args:
            port: Port number
            url: Optional URL override

        Returns:
            Status message
        """
        config = get_config()
        if url is None:
            url = f"http://{config.ghidra_host}:{port}"

        # Verify instance is responsive
        # NOTE(review): the probe is addressed by port only; a custom ``url``
        # is stored in the registry but not used for this check — confirm
        # that this is intended.
        try:
            response = safe_get(port, "")
            if not response.get("success", False):
                return f"Failed to connect to Ghidra instance on port {port}"

            # Check API version
            api_version = response.get("api_version", 0)
            if api_version < config.expected_api_version:
                return (
                    f"API version mismatch: got {api_version}, "
                    f"expected {config.expected_api_version}"
                )

            with cls._instances_lock:
                cls._instances[port] = {
                    "url": url,
                    "project": response.get("project", ""),
                    "file": response.get("file", ""),
                    "registered_at": time.time(),
                }

            return f"Registered Ghidra instance on port {port}"

        except Exception as e:
            # Boundary handler: any failure becomes a status message.
            return f"Error registering instance: {e}"

    @classmethod
    def unregister_instance(cls, port: int) -> str:
        """Unregister a Ghidra instance.

        If the removed instance was the current working instance, the current
        port is cleared as well.

        Args:
            port: Port number

        Returns:
            Status message
        """
        with cls._instances_lock:
            if port not in cls._instances:
                return f"No instance registered on port {port}"

            del cls._instances[port]
            # Clear the selection on the base class so that all mixins see it
            # (assigning through ``cls`` would only affect the calling subclass).
            if GhydraMixinBase._current_port == port:
                GhydraMixinBase._current_port = None
            return f"Unregistered Ghidra instance on port {port}"

    @classmethod
    def list_instances(cls) -> Dict[int, Dict[str, Any]]:
        """Get all registered instances.

        Returns:
            Dict mapping port to instance info. The outer dict is a copy; the
            per-instance info dicts are the registry's own objects.
        """
        with cls._instances_lock:
            return dict(cls._instances)

    @classmethod
    def get_instance_info(cls, port: int) -> Optional[Dict[str, Any]]:
        """Get info for a specific instance.

        Args:
            port: Port number

        Returns:
            Instance info dict or None
        """
        with cls._instances_lock:
            return cls._instances.get(port)

    def _get_session_id(self, ctx: Optional[Context]) -> str:
        """Extract session ID from FastMCP context.

        The first truthy attribute among ``session``, ``client_id`` and
        ``request_id`` decides the result; when none is usable the literal
        "default" is returned.

        Args:
            ctx: FastMCP context

        Returns:
            Session identifier string
        """
        if ctx is None:
            return "default"

        def read_attr(name: str) -> Any:
            # hasattr()/getattr() only swallow AttributeError. A property that
            # raises anything else must not abort this best-effort lookup, so
            # treat such a failure as "attribute not available".
            try:
                return getattr(ctx, name, None)
            except Exception:
                return None

        # Try various context attributes
        session = read_attr("session")
        if session:
            return str(session)

        client_id = read_attr("client_id")
        if client_id:
            return str(client_id)

        request_id = read_attr("request_id")
        if request_id:
            return f"req-{request_id}"

        return "default"

    # Convenience methods for subclasses.
    # Inside these bodies the bare names resolve to the module-level helpers
    # imported at the top of the file, not to the methods of the same name.
    def safe_get(self, port: int, endpoint: str, params: Optional[Dict] = None) -> Dict:
        """Make GET request to Ghidra instance."""
        return safe_get(port, endpoint, params)

    def safe_post(self, port: int, endpoint: str, data: Any) -> Dict:
        """Make POST request to Ghidra instance."""
        return safe_post(port, endpoint, data)

    def safe_put(self, port: int, endpoint: str, data: Dict) -> Dict:
        """Make PUT request to Ghidra instance."""
        return safe_put(port, endpoint, data)

    def safe_patch(self, port: int, endpoint: str, data: Dict) -> Dict:
        """Make PATCH request to Ghidra instance."""
        return safe_patch(port, endpoint, data)

    def safe_delete(self, port: int, endpoint: str) -> Dict:
        """Make DELETE request to Ghidra instance."""
        return safe_delete(port, endpoint)

    def simplify_response(self, response: Dict) -> Dict:
        """Simplify HATEOAS response."""
        return simplify_response(response)

    def paginate_response(
        self,
        data: list,
        query_params: Dict,
        tool_name: str,
        session_id: str = "default",
        page_size: int = 50,
        grep: Optional[str] = None,
        grep_ignorecase: bool = True,
        return_all: bool = False,
        fields: Optional[list] = None,
    ) -> Dict:
        """Create paginated response with optional field projection.

        Thin wrapper: every argument is forwarded unchanged, by keyword, to
        the module-level ``paginate_response`` helper.
        """
        return paginate_response(
            data=data,
            query_params=query_params,
            tool_name=tool_name,
            session_id=session_id,
            page_size=page_size,
            grep=grep,
            grep_ignorecase=grep_ignorecase,
            return_all=return_all,
            fields=fields,
        )

    def filtered_paginate(
        self,
        data: list,
        query_params: Dict,
        tool_name: str,
        session_id: str = "default",
        page_size: int = 50,
        grep: Optional[str] = None,
        grep_ignorecase: bool = True,
        return_all: bool = False,
        fields: Optional[list] = None,
    ) -> Dict:
        """Paginate with field projection and budget guard.

        Convenience alias that forwards all arguments unchanged to
        ``self.paginate_response``; this method adds no filtering of its own.
        Prefer this over paginate_response for any tool that could return
        large result sets.
        """
        return self.paginate_response(
            data=data,
            query_params=query_params,
            tool_name=tool_name,
            session_id=session_id,
            page_size=page_size,
            grep=grep,
            grep_ignorecase=grep_ignorecase,
            return_all=return_all,
            fields=fields,
        )

    # Async logging helpers.
    # Each awaits the module-level logging function of the same name.
    async def log_info(self, ctx: Optional[Context], message: str) -> None:
        """Log info message."""
        await log_info(ctx, message)

    async def log_debug(self, ctx: Optional[Context], message: str) -> None:
        """Log debug message."""
        await log_debug(ctx, message)

    async def log_warning(self, ctx: Optional[Context], message: str) -> None:
        """Log warning message."""
        await log_warning(ctx, message)

    async def log_error(self, ctx: Optional[Context], message: str) -> None:
        """Log error message."""
        await log_error(ctx, message)
|