mcvsphere/src/mcvsphere/middleware.py
Ryan Malloy 6159098963 feat: implement RBAC middleware for OAuth group-based permissions
Add RBACMiddleware that integrates with FastMCP's middleware system to
enforce role-based access control on all tool calls:

- Intercepts every tool call via on_call_tool() hook
- Extracts user groups from OAuth token claims
- Checks permissions using existing permissions.py mappings
- Logs all tool invocations with user identity via audit.py
- Denies access with clear PermissionDeniedError when unauthorized

Permission levels (from permissions.py):
- READ_ONLY: view operations (vsphere-readers)
- POWER_OPS: power/snapshot ops (vsphere-operators)
- VM_LIFECYCLE: create/delete VMs (vsphere-admins)
- HOST_ADMIN: ESXi host management (vsphere-host-admins)
- FULL_ADMIN: guest ops, services (vsphere-super-admins)

Middleware only enabled when OAuth is active (OAUTH_ENABLED=true).
STDIO mode continues to work without permission checking.
2025-12-27 07:37:26 -07:00

341 lines
10 KiB
Python

"""Middleware for permission checking and audit logging.
Provides decorators and hooks for wrapping tool execution with:
- OAuth permission validation
- Audit logging with user identity
- FastMCP middleware integration for RBAC
"""
import logging
import time
from collections.abc import Callable
from functools import wraps
from typing import TYPE_CHECKING, Any
import mcp.types as mt
from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext
from fastmcp.tools.tool import ToolResult
from mcvsphere.audit import (
audit_log,
audit_permission_denied,
get_current_user,
get_user_groups,
set_current_user,
)
from mcvsphere.permissions import (
PermissionDeniedError,
PermissionLevel,
check_permission,
get_required_permission,
)
if TYPE_CHECKING:
from fastmcp.server.context import Context
logger = logging.getLogger(__name__)
def with_permission_check(tool_name: str) -> Callable:
    """Decorator factory for permission checking and audit logging.

    The returned decorator wraps a tool function so that every call:

    1. reads the current user and groups via ``get_current_user()`` and
       ``get_user_groups()``,
    2. rejects the call if ``check_permission(tool_name, groups)`` is false,
    3. otherwise runs the tool, timing it and recording the outcome with
       ``audit_log``.

    The wrapped function is called directly (it is not awaited).

    Args:
        tool_name: Name of the MCP tool being wrapped.

    Returns:
        Decorator that wraps the tool function with permission checks and
        audit logging.

    Raises:
        PermissionDeniedError: Raised from the wrapped call (not from this
            factory) when the permission check fails.

    Example:
        @with_permission_check("power_on")
        def power_on(self, name: str) -> str:
            ...
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            # Get user groups and identity from the current context.
            groups = get_user_groups()
            user = get_current_user()

            # Use the first *non-empty* identity claim. A plain nested
            # ``dict.get`` chain would only fall through when a key is absent,
            # so a present-but-empty/None ``preferred_username`` would end up
            # as the username in the denial error.
            username = "anonymous"
            if user:
                username = next(
                    (
                        str(user[claim])
                        for claim in ("preferred_username", "email", "sub")
                        if user.get(claim)
                    ),
                    "unknown",
                )

            # Check permission before doing any work.
            if not check_permission(tool_name, groups):
                required = get_required_permission(tool_name)
                # Only keyword arguments are passed to the audit helpers;
                # positional arguments are not recorded.
                audit_permission_denied(tool_name, kwargs, required.value)
                raise PermissionDeniedError(username, tool_name, required)

            # Execute tool with timing.
            start_time = time.perf_counter()
            try:
                result = func(*args, **kwargs)
                duration_ms = (time.perf_counter() - start_time) * 1000
                audit_log(tool_name, kwargs, result=str(result), duration_ms=duration_ms)
                return result
            except PermissionDeniedError:
                # Re-raise permission errors without additional logging.
                raise
            except Exception as e:
                duration_ms = (time.perf_counter() - start_time) * 1000
                audit_log(tool_name, kwargs, error=str(e), duration_ms=duration_ms)
                raise

        return wrapper

    return decorator
def extract_user_from_context(ctx) -> dict[str, Any] | None:
"""Extract user information from FastMCP context.
Args:
ctx: FastMCP Context object.
Returns:
User info dict from OAuth token claims, or None if not authenticated.
"""
if ctx is None:
return None
# Try to get access token from context
try:
# FastMCP stores the access token in request_context
if hasattr(ctx, "request_context") and ctx.request_context:
token = getattr(ctx.request_context, "access_token", None)
if token and hasattr(token, "claims"):
return token.claims
except Exception:
pass
return None
def setup_user_context(ctx) -> None:
    """Publish the request's user info for later lookup.

    Intended to be called at the start of request handling: the claims
    extracted from ``ctx`` are handed to ``set_current_user()`` so they can
    be read back during the request via ``get_current_user()``. When no user
    can be extracted, ``None`` is stored.

    Args:
        ctx: FastMCP Context object.
    """
    set_current_user(extract_user_from_context(ctx))
class PermissionMiddleware:
    """Wraps tool functions with audit logging and, optionally, permission checks.

    Intended for wrapping mixin tool registration: each tool function is
    passed through ``wrap_tool`` before being registered.
    """

    def __init__(self, oauth_enabled: bool = False):
        """Create the middleware.

        Args:
            oauth_enabled: Whether OAuth authentication is enabled. Controls
                whether ``wrap_tool`` adds permission checks or only audit
                logging.
        """
        self.oauth_enabled = oauth_enabled

    def wrap_tool(self, tool_name: str, func: Callable) -> Callable:
        """Return ``func`` wrapped for auditing (and authorization if enabled).

        Args:
            tool_name: Name of the tool, used for permission lookup and in
                audit records.
            func: Original tool function.

        Returns:
            With OAuth enabled, ``func`` decorated by
            ``with_permission_check(tool_name)``. Otherwise a wrapper that
            only times the call and writes an audit record.
        """
        if self.oauth_enabled:
            # Auth is on: delegate to the permission-checking decorator.
            return with_permission_check(tool_name)(func)

        # Auth is off: audit-only wrapper.
        @wraps(func)
        def audited(*args, **kwargs) -> Any:
            began = time.perf_counter()
            try:
                outcome = func(*args, **kwargs)
                elapsed_ms = (time.perf_counter() - began) * 1000
                # Kept inside the try block on purpose: a failure while
                # writing the success record is handled by the except below.
                audit_log(tool_name, kwargs, result=str(outcome), duration_ms=elapsed_ms)
                return outcome
            except Exception as exc:
                elapsed_ms = (time.perf_counter() - began) * 1000
                audit_log(tool_name, kwargs, error=str(exc), duration_ms=elapsed_ms)
                raise

        return audited
def get_permission_summary() -> dict[str, list[str]]:
    """Summarize which tools belong to each permission level.

    Every member of ``PermissionLevel`` gets an entry (keyed by its
    ``.value``), even when no tool maps to it.

    Returns:
        Dict mapping permission level names to alphabetically sorted lists
        of tool names.
    """
    # Imported here rather than at module level, as in the original code.
    from mcvsphere.permissions import TOOL_PERMISSIONS

    # Seed one bucket per level so empty levels still appear in the result.
    buckets: dict[str, list[str]] = {level.value: [] for level in PermissionLevel}
    for name, required_level in TOOL_PERMISSIONS.items():
        buckets[required_level.value].append(name)

    return {level_name: sorted(names) for level_name, names in buckets.items()}
class RBACMiddleware(Middleware):
    """FastMCP middleware for Role-Based Access Control.

    Integrates with FastMCP's middleware system to enforce permissions
    on every tool call based on OAuth group memberships. Each call is
    checked with ``check_permission`` and recorded through the audit helpers.

    Example:
        mcp = FastMCP("my-server", auth=oauth_provider)
        mcp.add_middleware(RBACMiddleware())
    """

    def _extract_user_from_context(
        self, fastmcp_ctx: "Context | None"
    ) -> dict[str, Any] | None:
        """Extract user claims from FastMCP context.

        Two sources are tried in order:

        1. ``fastmcp_ctx.request_context.access_token.claims``
        2. ``fastmcp.server.dependencies.get_access_token().claims``

        Both lookups are best-effort; failures are logged at debug level and
        treated as "no claims".

        Args:
            fastmcp_ctx: FastMCP Context object from middleware context.

        Returns:
            User claims dict from OAuth token, or None if not authenticated.
        """
        if fastmcp_ctx is None:
            return None

        try:
            request_context = getattr(fastmcp_ctx, "request_context", None)
            if request_context:
                token = getattr(request_context, "access_token", None)
                if token and hasattr(token, "claims"):
                    return token.claims
        except Exception as e:
            logger.debug("Failed to extract user from context: %s", e)

        # NOTE(review): FastMCP documents get_access_token() as the way to
        # read the authenticated token and its claims. Falling back to it
        # keeps authenticated users from being treated as anonymous when the
        # request context carries no ``access_token`` attribute - confirm
        # against the pinned FastMCP version.
        try:
            from fastmcp.server.dependencies import get_access_token

            token = get_access_token()
        except Exception as e:
            logger.debug("Failed to get access token from FastMCP: %s", e)
            return None

        if token and hasattr(token, "claims"):
            return token.claims
        return None

    def _get_username(self, claims: dict[str, Any] | None) -> str:
        """Extract username from OAuth claims.

        Args:
            claims: OAuth token claims dict.

        Returns:
            The first non-empty value among ``preferred_username``, ``email``
            and ``sub`` (as a string); 'anonymous' if there are no claims;
            'unknown' if none of those claims has a value.
        """
        if not claims:
            return "anonymous"
        for claim in ("preferred_username", "email", "sub"):
            if value := claims.get(claim):
                return str(value)
        return "unknown"

    def _get_groups(self, claims: dict[str, Any] | None) -> list[str]:
        """Extract groups from OAuth claims.

        Args:
            claims: OAuth token claims dict.

        Returns:
            List of group names from the ``groups`` claim. Empty if there are
            no claims or the claim is not a list. Non-string entries are
            dropped so the result always matches ``list[str]``; this can only
            narrow, never widen, the set of groups used for authorization.
        """
        if not claims:
            return []
        groups = claims.get("groups", [])
        if not isinstance(groups, list):
            return []
        return [group for group in groups if isinstance(group, str)]

    async def on_call_tool(
        self,
        context: MiddlewareContext[mt.CallToolRequestParams],
        call_next: CallNext[mt.CallToolRequestParams, ToolResult],
    ) -> ToolResult:
        """Intercept tool calls to enforce RBAC permissions.

        Args:
            context: Middleware context containing tool call params.
            call_next: Next handler in the middleware chain.

        Returns:
            Tool result if permitted.

        Raises:
            PermissionDeniedError: If user lacks required permission.
            Exception: Any error raised by the downstream handler is audited,
                logged and re-raised unchanged.
        """
        # Extract tool name and arguments from the request.
        tool_name = context.message.name
        tool_args = context.message.arguments or {}

        # Get user info from OAuth context.
        claims = self._extract_user_from_context(context.fastmcp_context)
        username = self._get_username(claims)
        groups = self._get_groups(claims)

        # Make the claims available to the audit helpers for this request.
        set_current_user(claims)

        # Check permission before invoking the tool.
        if not check_permission(tool_name, groups):
            required = get_required_permission(tool_name)
            logger.warning(
                "Permission denied: user=%s groups=%s tool=%s required=%s",
                username,
                groups,
                tool_name,
                required.value,
            )
            audit_permission_denied(tool_name, tool_args, required.value)
            raise PermissionDeniedError(username, tool_name, required)

        # Permission granted - execute tool with timing.
        start_time = time.perf_counter()
        try:
            result = await call_next(context)
            duration_ms = (time.perf_counter() - start_time) * 1000
            # ToolResult can be complex, so only the fact that the call
            # succeeded is recorded rather than the result itself.
            audit_log(tool_name, tool_args, result="success", duration_ms=duration_ms)
            logger.debug(
                "Tool executed: user=%s tool=%s duration=%.2fms",
                username,
                tool_name,
                duration_ms,
            )
            return result
        except PermissionDeniedError:
            # Re-raise without additional logging.
            raise
        except Exception as e:
            duration_ms = (time.perf_counter() - start_time) * 1000
            audit_log(tool_name, tool_args, error=str(e), duration_ms=duration_ms)
            logger.error(
                "Tool failed: user=%s tool=%s error=%s duration=%.2fms",
                username,
                tool_name,
                str(e),
                duration_ms,
            )
            raise