feat: implement RBAC middleware for OAuth group-based permissions

Add RBACMiddleware that integrates with FastMCP's middleware system to
enforce role-based access control on all tool calls:

- Intercepts every tool call via on_call_tool() hook
- Extracts user groups from OAuth token claims
- Checks permissions using existing permissions.py mappings
- Logs all tool invocations with user identity via audit.py
- Denies access with clear PermissionDeniedError when unauthorized

Permission levels (from permissions.py):
- READ_ONLY: view operations (vsphere-readers)
- POWER_OPS: power/snapshot ops (vsphere-operators)
- VM_LIFECYCLE: create/delete VMs (vsphere-admins)
- HOST_ADMIN: ESXi host management (vsphere-host-admins)
- FULL_ADMIN: guest ops, services (vsphere-super-admins)

Middleware only enabled when OAuth is active (OAUTH_ENABLED=true).
STDIO mode continues to work without permission checking.
This commit is contained in:
Ryan Malloy 2025-12-27 07:37:26 -07:00
parent 0e29fea857
commit 6159098963
2 changed files with 168 additions and 2 deletions

View File

@ -3,12 +3,18 @@
Provides decorators and hooks for wrapping tool execution with: Provides decorators and hooks for wrapping tool execution with:
- OAuth permission validation - OAuth permission validation
- Audit logging with user identity - Audit logging with user identity
- FastMCP middleware integration for RBAC
""" """
import logging
import time import time
from collections.abc import Callable from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Any from typing import TYPE_CHECKING, Any
import mcp.types as mt
from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext
from fastmcp.tools.tool import ToolResult
from mcvsphere.audit import ( from mcvsphere.audit import (
audit_log, audit_log,
@ -24,6 +30,11 @@ from mcvsphere.permissions import (
get_required_permission, get_required_permission,
) )
if TYPE_CHECKING:
from fastmcp.server.context import Context
logger = logging.getLogger(__name__)
def with_permission_check(tool_name: str) -> Callable: def with_permission_check(tool_name: str) -> Callable:
"""Decorator factory for permission checking and audit logging. """Decorator factory for permission checking and audit logging.
@ -179,3 +190,151 @@ def get_permission_summary() -> dict[str, list[str]]:
summary[level].sort() summary[level].sort()
return summary return summary
class RBACMiddleware(Middleware):
"""FastMCP middleware for Role-Based Access Control.
Integrates with FastMCP's middleware system to enforce permissions
on every tool call based on OAuth group memberships.
Example:
mcp = FastMCP("my-server", auth=oauth_provider)
mcp.add_middleware(RBACMiddleware())
"""
def _extract_user_from_context(
self, fastmcp_ctx: "Context | None"
) -> dict[str, Any] | None:
"""Extract user claims from FastMCP context.
Args:
fastmcp_ctx: FastMCP Context object from middleware context.
Returns:
User claims dict from OAuth token, or None if not authenticated.
"""
if fastmcp_ctx is None:
return None
try:
# FastMCP stores access token in request_context
if hasattr(fastmcp_ctx, "request_context") and fastmcp_ctx.request_context:
token = getattr(fastmcp_ctx.request_context, "access_token", None)
if token and hasattr(token, "claims"):
return token.claims
except Exception as e:
logger.debug("Failed to extract user from context: %s", e)
return None
def _get_username(self, claims: dict[str, Any] | None) -> str:
"""Extract username from OAuth claims.
Args:
claims: OAuth token claims dict.
Returns:
Username string, or 'anonymous' if no claims.
"""
if not claims:
return "anonymous"
for claim in ("preferred_username", "email", "sub"):
if value := claims.get(claim):
return str(value)
return "unknown"
def _get_groups(self, claims: dict[str, Any] | None) -> list[str]:
"""Extract groups from OAuth claims.
Args:
claims: OAuth token claims dict.
Returns:
List of group names, or empty list if no claims.
"""
if not claims:
return []
groups = claims.get("groups", [])
if isinstance(groups, list):
return groups
return []
async def on_call_tool(
self,
context: MiddlewareContext[mt.CallToolRequestParams],
call_next: CallNext[mt.CallToolRequestParams, ToolResult],
) -> ToolResult:
"""Intercept tool calls to enforce RBAC permissions.
Args:
context: Middleware context containing tool call params.
call_next: Next handler in the middleware chain.
Returns:
Tool result if permitted.
Raises:
PermissionDeniedError: If user lacks required permission.
"""
# Extract tool name and arguments from the request
tool_name = context.message.name
tool_args = context.message.arguments or {}
# Get user info from OAuth context
claims = self._extract_user_from_context(context.fastmcp_context)
username = self._get_username(claims)
groups = self._get_groups(claims)
# Set up audit context for this request
set_current_user(claims)
# Check permission
if not check_permission(tool_name, groups):
required = get_required_permission(tool_name)
logger.warning(
"Permission denied: user=%s groups=%s tool=%s required=%s",
username,
groups,
tool_name,
required.value,
)
audit_permission_denied(tool_name, tool_args, required.value)
raise PermissionDeniedError(username, tool_name, required)
# Permission granted - execute tool with timing
start_time = time.perf_counter()
try:
result = await call_next(context)
duration_ms = (time.perf_counter() - start_time) * 1000
# Audit successful execution
# ToolResult can be complex, just log that it succeeded
audit_log(tool_name, tool_args, result="success", duration_ms=duration_ms)
logger.debug(
"Tool executed: user=%s tool=%s duration=%.2fms",
username,
tool_name,
duration_ms,
)
return result
except PermissionDeniedError:
# Re-raise without additional logging
raise
except Exception as e:
duration_ms = (time.perf_counter() - start_time) * 1000
audit_log(tool_name, tool_args, error=str(e), duration_ms=duration_ms)
logger.error(
"Tool failed: user=%s tool=%s error=%s duration=%.2fms",
username,
tool_name,
str(e),
duration_ms,
)
raise

View File

@ -9,6 +9,7 @@ from fastmcp import FastMCP
from mcvsphere.auth import create_auth_provider from mcvsphere.auth import create_auth_provider
from mcvsphere.config import Settings, get_settings from mcvsphere.config import Settings, get_settings
from mcvsphere.connection import VMwareConnection from mcvsphere.connection import VMwareConnection
from mcvsphere.middleware import RBACMiddleware
from mcvsphere.mixins import ( from mcvsphere.mixins import (
ConsoleMixin, ConsoleMixin,
DiskManagementMixin, DiskManagementMixin,
@ -68,6 +69,11 @@ def create_server(settings: Settings | None = None) -> FastMCP:
auth=auth, auth=auth,
) )
# Add RBAC middleware when OAuth is enabled
if settings.oauth_enabled:
mcp.add_middleware(RBACMiddleware())
logger.info("RBAC middleware enabled - permissions enforced via OAuth groups")
# Create shared VMware connection # Create shared VMware connection
logger.info("Connecting to VMware vCenter/ESXi...") logger.info("Connecting to VMware vCenter/ESXi...")
conn = VMwareConnection(settings) conn = VMwareConnection(settings)
@ -137,8 +143,9 @@ def run_server(config_path: Path | None = None) -> None:
) )
if settings.oauth_enabled: if settings.oauth_enabled:
print(f"OAuth: ENABLED via {settings.oauth_issuer_url}", file=sys.stderr) print(f"OAuth: ENABLED via {settings.oauth_issuer_url}", file=sys.stderr)
print("RBAC: ENABLED - permissions enforced via groups", file=sys.stderr)
else: else:
print("OAuth: disabled", file=sys.stderr) print("OAuth: disabled (single-user mode)", file=sys.stderr)
print("" * 40, file=sys.stderr) print("" * 40, file=sys.stderr)
# Create and run server # Create and run server