feat: implement RBAC middleware for OAuth group-based permissions

Add RBACMiddleware that integrates with FastMCP's middleware system to
enforce role-based access control on all tool calls:

- Intercepts every tool call via on_call_tool() hook
- Extracts user groups from OAuth token claims
- Checks permissions using existing permissions.py mappings
- Logs all tool invocations with user identity via audit.py
- Denies access with clear PermissionDeniedError when unauthorized

Permission levels (from permissions.py):
- READ_ONLY: view operations (vsphere-readers)
- POWER_OPS: power/snapshot ops (vsphere-operators)
- VM_LIFECYCLE: create/delete VMs (vsphere-admins)
- HOST_ADMIN: ESXi host management (vsphere-host-admins)
- FULL_ADMIN: guest ops, services (vsphere-super-admins)

Middleware only enabled when OAuth is active (OAUTH_ENABLED=true).
STDIO mode continues to work without permission checking.
This commit is contained in:
Ryan Malloy 2025-12-27 07:37:26 -07:00
parent 0e29fea857
commit 6159098963
2 changed files with 168 additions and 2 deletions

View File

@ -3,12 +3,18 @@
Provides decorators and hooks for wrapping tool execution with:
- OAuth permission validation
- Audit logging with user identity
- FastMCP middleware integration for RBAC
"""
import logging
import time
from collections.abc import Callable
from functools import wraps
from typing import Any
from typing import TYPE_CHECKING, Any
import mcp.types as mt
from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext
from fastmcp.tools.tool import ToolResult
from mcvsphere.audit import (
audit_log,
@ -24,6 +30,11 @@ from mcvsphere.permissions import (
get_required_permission,
)
if TYPE_CHECKING:
from fastmcp.server.context import Context
logger = logging.getLogger(__name__)
def with_permission_check(tool_name: str) -> Callable:
"""Decorator factory for permission checking and audit logging.
@ -179,3 +190,151 @@ def get_permission_summary() -> dict[str, list[str]]:
summary[level].sort()
return summary
class RBACMiddleware(Middleware):
"""FastMCP middleware for Role-Based Access Control.
Integrates with FastMCP's middleware system to enforce permissions
on every tool call based on OAuth group memberships.
Example:
mcp = FastMCP("my-server", auth=oauth_provider)
mcp.add_middleware(RBACMiddleware())
"""
def _extract_user_from_context(
self, fastmcp_ctx: "Context | None"
) -> dict[str, Any] | None:
"""Extract user claims from FastMCP context.
Args:
fastmcp_ctx: FastMCP Context object from middleware context.
Returns:
User claims dict from OAuth token, or None if not authenticated.
"""
if fastmcp_ctx is None:
return None
try:
# FastMCP stores access token in request_context
if hasattr(fastmcp_ctx, "request_context") and fastmcp_ctx.request_context:
token = getattr(fastmcp_ctx.request_context, "access_token", None)
if token and hasattr(token, "claims"):
return token.claims
except Exception as e:
logger.debug("Failed to extract user from context: %s", e)
return None
def _get_username(self, claims: dict[str, Any] | None) -> str:
"""Extract username from OAuth claims.
Args:
claims: OAuth token claims dict.
Returns:
Username string, or 'anonymous' if no claims.
"""
if not claims:
return "anonymous"
for claim in ("preferred_username", "email", "sub"):
if value := claims.get(claim):
return str(value)
return "unknown"
def _get_groups(self, claims: dict[str, Any] | None) -> list[str]:
"""Extract groups from OAuth claims.
Args:
claims: OAuth token claims dict.
Returns:
List of group names, or empty list if no claims.
"""
if not claims:
return []
groups = claims.get("groups", [])
if isinstance(groups, list):
return groups
return []
async def on_call_tool(
self,
context: MiddlewareContext[mt.CallToolRequestParams],
call_next: CallNext[mt.CallToolRequestParams, ToolResult],
) -> ToolResult:
"""Intercept tool calls to enforce RBAC permissions.
Args:
context: Middleware context containing tool call params.
call_next: Next handler in the middleware chain.
Returns:
Tool result if permitted.
Raises:
PermissionDeniedError: If user lacks required permission.
"""
# Extract tool name and arguments from the request
tool_name = context.message.name
tool_args = context.message.arguments or {}
# Get user info from OAuth context
claims = self._extract_user_from_context(context.fastmcp_context)
username = self._get_username(claims)
groups = self._get_groups(claims)
# Set up audit context for this request
set_current_user(claims)
# Check permission
if not check_permission(tool_name, groups):
required = get_required_permission(tool_name)
logger.warning(
"Permission denied: user=%s groups=%s tool=%s required=%s",
username,
groups,
tool_name,
required.value,
)
audit_permission_denied(tool_name, tool_args, required.value)
raise PermissionDeniedError(username, tool_name, required)
# Permission granted - execute tool with timing
start_time = time.perf_counter()
try:
result = await call_next(context)
duration_ms = (time.perf_counter() - start_time) * 1000
# Audit successful execution
# ToolResult can be complex, just log that it succeeded
audit_log(tool_name, tool_args, result="success", duration_ms=duration_ms)
logger.debug(
"Tool executed: user=%s tool=%s duration=%.2fms",
username,
tool_name,
duration_ms,
)
return result
except PermissionDeniedError:
# Re-raise without additional logging
raise
except Exception as e:
duration_ms = (time.perf_counter() - start_time) * 1000
audit_log(tool_name, tool_args, error=str(e), duration_ms=duration_ms)
logger.error(
"Tool failed: user=%s tool=%s error=%s duration=%.2fms",
username,
tool_name,
str(e),
duration_ms,
)
raise

View File

@ -9,6 +9,7 @@ from fastmcp import FastMCP
from mcvsphere.auth import create_auth_provider
from mcvsphere.config import Settings, get_settings
from mcvsphere.connection import VMwareConnection
from mcvsphere.middleware import RBACMiddleware
from mcvsphere.mixins import (
ConsoleMixin,
DiskManagementMixin,
@ -68,6 +69,11 @@ def create_server(settings: Settings | None = None) -> FastMCP:
auth=auth,
)
# Add RBAC middleware when OAuth is enabled
if settings.oauth_enabled:
mcp.add_middleware(RBACMiddleware())
logger.info("RBAC middleware enabled - permissions enforced via OAuth groups")
# Create shared VMware connection
logger.info("Connecting to VMware vCenter/ESXi...")
conn = VMwareConnection(settings)
@ -137,8 +143,9 @@ def run_server(config_path: Path | None = None) -> None:
)
if settings.oauth_enabled:
print(f"OAuth: ENABLED via {settings.oauth_issuer_url}", file=sys.stderr)
print("RBAC: ENABLED - permissions enforced via groups", file=sys.stderr)
else:
print("OAuth: disabled", file=sys.stderr)
print("OAuth: disabled (single-user mode)", file=sys.stderr)
print("" * 40, file=sys.stderr)
# Create and run server