From 6159098963c4ce175114ab1879bb2c839c2faf5b Mon Sep 17 00:00:00 2001 From: Ryan Malloy Date: Sat, 27 Dec 2025 07:37:26 -0700 Subject: [PATCH] feat: implement RBAC middleware for OAuth group-based permissions Add RBACMiddleware that integrates with FastMCP's middleware system to enforce role-based access control on all tool calls: - Intercepts every tool call via on_call_tool() hook - Extracts user groups from OAuth token claims - Checks permissions using existing permissions.py mappings - Logs all tool invocations with user identity via audit.py - Denies access with clear PermissionDeniedError when unauthorized Permission levels (from permissions.py): - READ_ONLY: view operations (vsphere-readers) - POWER_OPS: power/snapshot ops (vsphere-operators) - VM_LIFECYCLE: create/delete VMs (vsphere-admins) - HOST_ADMIN: ESXi host management (vsphere-host-admins) - FULL_ADMIN: guest ops, services (vsphere-super-admins) Middleware only enabled when OAuth is active (OAUTH_ENABLED=true). STDIO mode continues to work without permission checking. --- src/mcvsphere/middleware.py | 161 +++++++++++++++++++++++++++++++++++- src/mcvsphere/server.py | 9 +- 2 files changed, 168 insertions(+), 2 deletions(-) diff --git a/src/mcvsphere/middleware.py b/src/mcvsphere/middleware.py index 183a545..c61aa34 100644 --- a/src/mcvsphere/middleware.py +++ b/src/mcvsphere/middleware.py @@ -3,12 +3,18 @@ Provides decorators and hooks for wrapping tool execution with: - OAuth permission validation - Audit logging with user identity +- FastMCP middleware integration for RBAC """ +import logging import time from collections.abc import Callable from functools import wraps -from typing import Any +from typing import TYPE_CHECKING, Any + +import mcp.types as mt +from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext +from fastmcp.tools.tool import ToolResult from mcvsphere.audit import ( audit_log, @@ -24,6 +30,11 @@ from mcvsphere.permissions import ( get_required_permission, ) +if TYPE_CHECKING: + from fastmcp.server.context import Context + +logger = logging.getLogger(__name__) + def with_permission_check(tool_name: str) -> Callable: """Decorator factory for permission checking and audit logging. @@ -179,3 +190,151 @@ def get_permission_summary() -> dict[str, list[str]]: summary[level].sort() return summary + + +class RBACMiddleware(Middleware): + """FastMCP middleware for Role-Based Access Control. + + Integrates with FastMCP's middleware system to enforce permissions + on every tool call based on OAuth group memberships. + + Example: + mcp = FastMCP("my-server", auth=oauth_provider) + mcp.add_middleware(RBACMiddleware()) + """ + + def _extract_user_from_context( + self, fastmcp_ctx: "Context | None" + ) -> dict[str, Any] | None: + """Extract user claims from FastMCP context. + + Args: + fastmcp_ctx: FastMCP Context object from middleware context. + + Returns: + User claims dict from OAuth token, or None if not authenticated. + """ + if fastmcp_ctx is None: + return None + + try: + # FastMCP stores access token in request_context + if hasattr(fastmcp_ctx, "request_context") and fastmcp_ctx.request_context: + token = getattr(fastmcp_ctx.request_context, "access_token", None) + if token and hasattr(token, "claims"): + return token.claims + except Exception as e: + logger.debug("Failed to extract user from context: %s", e) + + return None + + def _get_username(self, claims: dict[str, Any] | None) -> str: + """Extract username from OAuth claims. + + Args: + claims: OAuth token claims dict. + + Returns: + Username string, or 'anonymous' if no claims. + """ + if not claims: + return "anonymous" + + for claim in ("preferred_username", "email", "sub"): + if value := claims.get(claim): + return str(value) + + return "unknown" + + def _get_groups(self, claims: dict[str, Any] | None) -> list[str]: + """Extract groups from OAuth claims. + + Args: + claims: OAuth token claims dict. + + Returns: + List of group names, or empty list if no claims. + """ + if not claims: + return [] + + groups = claims.get("groups", []) + if isinstance(groups, list): + return groups + return [] + + async def on_call_tool( + self, + context: MiddlewareContext[mt.CallToolRequestParams], + call_next: CallNext[mt.CallToolRequestParams, ToolResult], + ) -> ToolResult: + """Intercept tool calls to enforce RBAC permissions. + + Args: + context: Middleware context containing tool call params. + call_next: Next handler in the middleware chain. + + Returns: + Tool result if permitted. + + Raises: + PermissionDeniedError: If user lacks required permission. + """ + # Extract tool name and arguments from the request + tool_name = context.message.name + tool_args = context.message.arguments or {} + + # Get user info from OAuth context + claims = self._extract_user_from_context(context.fastmcp_context) + username = self._get_username(claims) + groups = self._get_groups(claims) + + # Set up audit context for this request + set_current_user(claims) + + # Check permission + if not check_permission(tool_name, groups): + required = get_required_permission(tool_name) + logger.warning( + "Permission denied: user=%s groups=%s tool=%s required=%s", + username, + groups, + tool_name, + required.value, + ) + audit_permission_denied(tool_name, tool_args, required.value) + raise PermissionDeniedError(username, tool_name, required) + + # Permission granted - execute tool with timing + start_time = time.perf_counter() + try: + result = await call_next(context) + duration_ms = (time.perf_counter() - start_time) * 1000 + + # Audit successful execution + # ToolResult can be complex, just log that it succeeded + audit_log(tool_name, tool_args, result="success", duration_ms=duration_ms) + + logger.debug( + "Tool executed: user=%s tool=%s duration=%.2fms", + username, + tool_name, + duration_ms, + ) + + return result + + except PermissionDeniedError: + # Re-raise without additional logging + raise + except Exception as e: + duration_ms = (time.perf_counter() - start_time) * 1000 + audit_log(tool_name, tool_args, error=str(e), duration_ms=duration_ms) + logger.error( + "Tool failed: user=%s tool=%s error=%s duration=%.2fms", + username, + tool_name, + str(e), + duration_ms, + ) + raise diff --git a/src/mcvsphere/server.py b/src/mcvsphere/server.py index d998286..499c058 100644 --- a/src/mcvsphere/server.py +++ b/src/mcvsphere/server.py @@ -9,6 +9,7 @@ from fastmcp import FastMCP from mcvsphere.auth import create_auth_provider from mcvsphere.config import Settings, get_settings from mcvsphere.connection import VMwareConnection +from mcvsphere.middleware import RBACMiddleware from mcvsphere.mixins import ( ConsoleMixin, DiskManagementMixin, @@ -68,6 +69,11 @@ def create_server(settings: Settings | None = None) -> FastMCP: auth=auth, ) + # Add RBAC middleware when OAuth is enabled + if settings.oauth_enabled: + mcp.add_middleware(RBACMiddleware()) + logger.info("RBAC middleware enabled - permissions enforced via OAuth groups") + # Create shared VMware connection logger.info("Connecting to VMware vCenter/ESXi...") conn = VMwareConnection(settings) @@ -137,8 +143,9 @@ def run_server(config_path: Path | None = None) -> None: ) if settings.oauth_enabled: print(f"OAuth: ENABLED via {settings.oauth_issuer_url}", file=sys.stderr) + print("RBAC: ENABLED - permissions enforced via groups", file=sys.stderr) else: - print("OAuth: disabled", file=sys.stderr) + print("OAuth: disabled (single-user mode)", file=sys.stderr) print("─" * 40, file=sys.stderr) # Create and run server