"""Base mixin class for GhydraMCP domain mixins.

Provides shared state and utilities for all domain mixins.
"""

import time
from threading import Lock
from typing import Any, Dict, Optional

from fastmcp import Context
from fastmcp.contrib.mcp_mixin import MCPMixin

from ..config import get_config
from ..core.http_client import (
    safe_get,
    safe_post,
    safe_put,
    safe_patch,
    safe_delete,
    simplify_response,
)
from ..core.pagination import paginate_response
from ..core.logging import log_info, log_debug, log_warning, log_error


class GhydraMixinBase(MCPMixin):
    """Base class for GhydraMCP domain mixins.

    Provides shared instance state and common utilities.
    All domain mixins should inherit from this class.

    The instance registry (``_instances``) and the current working port
    (``_current_port``) are stored on ``GhydraMixinBase`` itself. Every
    accessor below reads and writes them through ``GhydraMixinBase``
    explicitly. Assigning through ``cls`` from a subclass would instead
    create a subclass-local attribute that shadows the shared one. Other
    mixins would then never see the update.
    """

    # Shared state across all mixins. Always assign via GhydraMixinBase,
    # never via ``cls``/``self``, so every subclass observes the same value.
    _instances: Dict[int, Dict[str, Any]] = {}
    # Held whenever _instances or _current_port is mutated.
    _instances_lock = Lock()
    _current_port: Optional[int] = None

    def __init__(self):
        """Initialize the mixin; all state lives at class level."""
        # Cooperate with any other classes in the MRO.
        super().__init__()

    @classmethod
    def get_current_port(cls) -> Optional[int]:
        """Get the current working instance port (shared by all mixins)."""
        return GhydraMixinBase._current_port

    @classmethod
    def set_current_port(cls, port: int) -> None:
        """Set the current working instance port (shared by all mixins).

        Args:
            port: Port number to make the current working instance
        """
        # Take the lock so this cannot interleave with the check-and-clear
        # in unregister_instance.
        with GhydraMixinBase._instances_lock:
            GhydraMixinBase._current_port = port

    @classmethod
    def get_instance_port(cls, port: Optional[int] = None) -> int:
        """Get instance port, using current if not specified.

        Resolution order: the explicit ``port`` argument, then the shared
        current port, then the first port of the configured quick-discovery
        range (only if an instance is registered there).

        Args:
            port: Explicit port (optional)

        Returns:
            Port number to use

        Raises:
            ValueError: If no port specified and no current instance set
        """
        if port is not None:
            return port

        current = GhydraMixinBase._current_port
        if current is not None:
            return current

        config = get_config()

        # Try default port
        default_port = config.quick_discovery_range.start
        with GhydraMixinBase._instances_lock:
            default_registered = default_port in GhydraMixinBase._instances
        if default_registered:
            return default_port

        raise ValueError(
            "No Ghidra instance specified. Use instances_use(port) to set a working instance, "
            "or pass port= parameter explicitly."
        )

    @classmethod
    def register_instance(cls, port: int, url: Optional[str] = None) -> str:
        """Register a Ghidra instance.

        Probes the instance with ``safe_get(port, "")`` and checks the
        reported ``api_version``. On success it records url/project/file and
        a registration timestamp in the shared registry.

        Args:
            port: Port number
            url: Optional URL override

        Returns:
            Status message (errors are reported as messages, not raised)
        """
        config = get_config()

        if url is None:
            url = f"http://{config.ghidra_host}:{port}"

        # Verify instance is responsive.
        # NOTE(review): the probe is addressed by port only, so a custom
        # ``url`` is stored but not used for this check — confirm intent.
        try:
            response = safe_get(port, "")
            if not response.get("success", False):
                return f"Failed to connect to Ghidra instance on port {port}"

            # Check API version
            api_version = response.get("api_version", 0)
            if api_version < config.expected_api_version:
                return (
                    f"API version mismatch: got {api_version}, "
                    f"expected {config.expected_api_version}"
                )

            with GhydraMixinBase._instances_lock:
                GhydraMixinBase._instances[port] = {
                    "url": url,
                    "project": response.get("project", ""),
                    "file": response.get("file", ""),
                    "registered_at": time.time(),
                }

            return f"Registered Ghidra instance on port {port}"

        except Exception as e:
            # Tool boundary: report any failure as a status message.
            return f"Error registering instance: {e}"

    @classmethod
    def unregister_instance(cls, port: int) -> str:
        """Unregister a Ghidra instance.

        Also clears the shared current port if it pointed at ``port``.

        Args:
            port: Port number

        Returns:
            Status message
        """
        with GhydraMixinBase._instances_lock:
            if port not in GhydraMixinBase._instances:
                return f"No instance registered on port {port}"
            del GhydraMixinBase._instances[port]
            if GhydraMixinBase._current_port == port:
                GhydraMixinBase._current_port = None
        return f"Unregistered Ghidra instance on port {port}"

    @classmethod
    def list_instances(cls) -> Dict[int, Dict[str, Any]]:
        """Get all registered instances.

        Returns:
            Dict mapping port to instance info. The outer dict is a copy;
            the per-instance info dicts are the stored objects.
        """
        with GhydraMixinBase._instances_lock:
            return dict(GhydraMixinBase._instances)

    @classmethod
    def get_instance_info(cls, port: int) -> Optional[Dict[str, Any]]:
        """Get info for a specific instance.

        Args:
            port: Port number

        Returns:
            The stored instance info dict (not a copy), or None
        """
        with GhydraMixinBase._instances_lock:
            return GhydraMixinBase._instances.get(port)

    def _get_session_id(self, ctx: Optional[Context]) -> str:
        """Extract session ID from FastMCP context.

        Tries ``ctx.session``, then ``ctx.client_id``, then
        ``ctx.request_id``; falls back to ``"default"``.

        Args:
            ctx: FastMCP context

        Returns:
            Session identifier string
        """
        if ctx is None:
            return "default"

        # Try various context attributes
        if hasattr(ctx, "session") and ctx.session:
            return str(ctx.session)
        if hasattr(ctx, "client_id") and ctx.client_id:
            return str(ctx.client_id)
        if hasattr(ctx, "request_id") and ctx.request_id:
            return f"req-{ctx.request_id}"

        return "default"

    # Convenience methods for subclasses. Inside each method body the bare
    # name (e.g. ``safe_get``) resolves to the module-level import, not to
    # the method of the same name, because class scope is not enclosing.

    def safe_get(self, port: int, endpoint: str, params: Optional[Dict] = None) -> Dict:
        """Make GET request to Ghidra instance."""
        return safe_get(port, endpoint, params)

    def safe_post(self, port: int, endpoint: str, data: Any) -> Dict:
        """Make POST request to Ghidra instance."""
        return safe_post(port, endpoint, data)

    def safe_put(self, port: int, endpoint: str, data: Dict) -> Dict:
        """Make PUT request to Ghidra instance."""
        return safe_put(port, endpoint, data)

    def safe_patch(self, port: int, endpoint: str, data: Dict) -> Dict:
        """Make PATCH request to Ghidra instance."""
        return safe_patch(port, endpoint, data)

    def safe_delete(self, port: int, endpoint: str) -> Dict:
        """Make DELETE request to Ghidra instance."""
        return safe_delete(port, endpoint)

    def simplify_response(self, response: Dict) -> Dict:
        """Simplify HATEOAS response."""
        return simplify_response(response)

    def paginate_response(
        self,
        data: list,
        query_params: Dict,
        tool_name: str,
        session_id: str = "default",
        page_size: int = 50,
        grep: Optional[str] = None,
        grep_ignorecase: bool = True,
        return_all: bool = False,
        fields: Optional[list] = None,
    ) -> Dict:
        """Create paginated response with optional field projection."""
        return paginate_response(
            data=data,
            query_params=query_params,
            tool_name=tool_name,
            session_id=session_id,
            page_size=page_size,
            grep=grep,
            grep_ignorecase=grep_ignorecase,
            return_all=return_all,
            fields=fields,
        )

    def filtered_paginate(
        self,
        data: list,
        query_params: Dict,
        tool_name: str,
        session_id: str = "default",
        page_size: int = 50,
        grep: Optional[str] = None,
        grep_ignorecase: bool = True,
        return_all: bool = False,
        fields: Optional[list] = None,
    ) -> Dict:
        """Paginate results, forwarding all options to ``paginate_response``.

        This wrapper performs no processing of its own. Field projection
        (``fields``), grep filtering and ``return_all`` handling are all done
        by ``paginate_response``, which receives every argument unchanged.
        Kept as the preferred entry point for tools that may return large
        result sets.
        """
        return self.paginate_response(
            data=data,
            query_params=query_params,
            tool_name=tool_name,
            session_id=session_id,
            page_size=page_size,
            grep=grep,
            grep_ignorecase=grep_ignorecase,
            return_all=return_all,
            fields=fields,
        )

    # Async logging helpers

    async def log_info(self, ctx: Optional[Context], message: str) -> None:
        """Log info message."""
        await log_info(ctx, message)

    async def log_debug(self, ctx: Optional[Context], message: str) -> None:
        """Log debug message."""
        await log_debug(ctx, message)

    async def log_warning(self, ctx: Optional[Context], message: str) -> None:
        """Log warning message."""
        await log_warning(ctx, message)

    async def log_error(self, ctx: Optional[Context], message: str) -> None:
        """Log error message."""
        await log_error(ctx, message)