feat: implement RBAC middleware for OAuth group-based permissions
Add RBACMiddleware that integrates with FastMCP's middleware system to enforce role-based access control on all tool calls: - Intercepts every tool call via on_call_tool() hook - Extracts user groups from OAuth token claims - Checks permissions using existing permissions.py mappings - Logs all tool invocations with user identity via audit.py - Denies access with clear PermissionDeniedError when unauthorized Permission levels (from permissions.py): - READ_ONLY: view operations (vsphere-readers) - POWER_OPS: power/snapshot ops (vsphere-operators) - VM_LIFECYCLE: create/delete VMs (vsphere-admins) - HOST_ADMIN: ESXi host management (vsphere-host-admins) - FULL_ADMIN: guest ops, services (vsphere-super-admins) Middleware only enabled when OAuth is active (OAUTH_ENABLED=true). STDIO mode continues to work without permission checking.
This commit is contained in:
parent
0e29fea857
commit
6159098963
@ -3,12 +3,18 @@
|
||||
Provides decorators and hooks for wrapping tool execution with:
|
||||
- OAuth permission validation
|
||||
- Audit logging with user identity
|
||||
- FastMCP middleware integration for RBAC
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import mcp.types as mt
|
||||
from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
|
||||
from mcvsphere.audit import (
|
||||
audit_log,
|
||||
@ -24,6 +30,11 @@ from mcvsphere.permissions import (
|
||||
get_required_permission,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastmcp.server.context import Context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def with_permission_check(tool_name: str) -> Callable:
|
||||
"""Decorator factory for permission checking and audit logging.
|
||||
@ -179,3 +190,151 @@ def get_permission_summary() -> dict[str, list[str]]:
|
||||
summary[level].sort()
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
class RBACMiddleware(Middleware):
|
||||
"""FastMCP middleware for Role-Based Access Control.
|
||||
|
||||
Integrates with FastMCP's middleware system to enforce permissions
|
||||
on every tool call based on OAuth group memberships.
|
||||
|
||||
Example:
|
||||
mcp = FastMCP("my-server", auth=oauth_provider)
|
||||
mcp.add_middleware(RBACMiddleware())
|
||||
"""
|
||||
|
||||
def _extract_user_from_context(
|
||||
self, fastmcp_ctx: "Context | None"
|
||||
) -> dict[str, Any] | None:
|
||||
"""Extract user claims from FastMCP context.
|
||||
|
||||
Args:
|
||||
fastmcp_ctx: FastMCP Context object from middleware context.
|
||||
|
||||
Returns:
|
||||
User claims dict from OAuth token, or None if not authenticated.
|
||||
"""
|
||||
if fastmcp_ctx is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
# FastMCP stores access token in request_context
|
||||
if hasattr(fastmcp_ctx, "request_context") and fastmcp_ctx.request_context:
|
||||
token = getattr(fastmcp_ctx.request_context, "access_token", None)
|
||||
if token and hasattr(token, "claims"):
|
||||
return token.claims
|
||||
except Exception as e:
|
||||
logger.debug("Failed to extract user from context: %s", e)
|
||||
|
||||
return None
|
||||
|
||||
def _get_username(self, claims: dict[str, Any] | None) -> str:
|
||||
"""Extract username from OAuth claims.
|
||||
|
||||
Args:
|
||||
claims: OAuth token claims dict.
|
||||
|
||||
Returns:
|
||||
Username string, or 'anonymous' if no claims.
|
||||
"""
|
||||
if not claims:
|
||||
return "anonymous"
|
||||
|
||||
for claim in ("preferred_username", "email", "sub"):
|
||||
if value := claims.get(claim):
|
||||
return str(value)
|
||||
|
||||
return "unknown"
|
||||
|
||||
def _get_groups(self, claims: dict[str, Any] | None) -> list[str]:
|
||||
"""Extract groups from OAuth claims.
|
||||
|
||||
Args:
|
||||
claims: OAuth token claims dict.
|
||||
|
||||
Returns:
|
||||
List of group names, or empty list if no claims.
|
||||
"""
|
||||
if not claims:
|
||||
return []
|
||||
|
||||
groups = claims.get("groups", [])
|
||||
if isinstance(groups, list):
|
||||
return groups
|
||||
return []
|
||||
|
||||
async def on_call_tool(
|
||||
self,
|
||||
context: MiddlewareContext[mt.CallToolRequestParams],
|
||||
call_next: CallNext[mt.CallToolRequestParams, ToolResult],
|
||||
) -> ToolResult:
|
||||
"""Intercept tool calls to enforce RBAC permissions.
|
||||
|
||||
Args:
|
||||
context: Middleware context containing tool call params.
|
||||
call_next: Next handler in the middleware chain.
|
||||
|
||||
Returns:
|
||||
Tool result if permitted.
|
||||
|
||||
Raises:
|
||||
PermissionDeniedError: If user lacks required permission.
|
||||
"""
|
||||
# Extract tool name and arguments from the request
|
||||
tool_name = context.message.name
|
||||
tool_args = context.message.arguments or {}
|
||||
|
||||
# Get user info from OAuth context
|
||||
claims = self._extract_user_from_context(context.fastmcp_context)
|
||||
username = self._get_username(claims)
|
||||
groups = self._get_groups(claims)
|
||||
|
||||
# Set up audit context for this request
|
||||
set_current_user(claims)
|
||||
|
||||
# Check permission
|
||||
if not check_permission(tool_name, groups):
|
||||
required = get_required_permission(tool_name)
|
||||
logger.warning(
|
||||
"Permission denied: user=%s groups=%s tool=%s required=%s",
|
||||
username,
|
||||
groups,
|
||||
tool_name,
|
||||
required.value,
|
||||
)
|
||||
audit_permission_denied(tool_name, tool_args, required.value)
|
||||
raise PermissionDeniedError(username, tool_name, required)
|
||||
|
||||
# Permission granted - execute tool with timing
|
||||
start_time = time.perf_counter()
|
||||
try:
|
||||
result = await call_next(context)
|
||||
duration_ms = (time.perf_counter() - start_time) * 1000
|
||||
|
||||
# Audit successful execution
|
||||
# ToolResult can be complex, just log that it succeeded
|
||||
audit_log(tool_name, tool_args, result="success", duration_ms=duration_ms)
|
||||
|
||||
logger.debug(
|
||||
"Tool executed: user=%s tool=%s duration=%.2fms",
|
||||
username,
|
||||
tool_name,
|
||||
duration_ms,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except PermissionDeniedError:
|
||||
# Re-raise without additional logging
|
||||
raise
|
||||
except Exception as e:
|
||||
duration_ms = (time.perf_counter() - start_time) * 1000
|
||||
audit_log(tool_name, tool_args, error=str(e), duration_ms=duration_ms)
|
||||
logger.error(
|
||||
"Tool failed: user=%s tool=%s error=%s duration=%.2fms",
|
||||
username,
|
||||
tool_name,
|
||||
str(e),
|
||||
duration_ms,
|
||||
)
|
||||
raise
|
||||
|
||||
@ -9,6 +9,7 @@ from fastmcp import FastMCP
|
||||
from mcvsphere.auth import create_auth_provider
|
||||
from mcvsphere.config import Settings, get_settings
|
||||
from mcvsphere.connection import VMwareConnection
|
||||
from mcvsphere.middleware import RBACMiddleware
|
||||
from mcvsphere.mixins import (
|
||||
ConsoleMixin,
|
||||
DiskManagementMixin,
|
||||
@ -68,6 +69,11 @@ def create_server(settings: Settings | None = None) -> FastMCP:
|
||||
auth=auth,
|
||||
)
|
||||
|
||||
# Add RBAC middleware when OAuth is enabled
|
||||
if settings.oauth_enabled:
|
||||
mcp.add_middleware(RBACMiddleware())
|
||||
logger.info("RBAC middleware enabled - permissions enforced via OAuth groups")
|
||||
|
||||
# Create shared VMware connection
|
||||
logger.info("Connecting to VMware vCenter/ESXi...")
|
||||
conn = VMwareConnection(settings)
|
||||
@ -137,8 +143,9 @@ def run_server(config_path: Path | None = None) -> None:
|
||||
)
|
||||
if settings.oauth_enabled:
|
||||
print(f"OAuth: ENABLED via {settings.oauth_issuer_url}", file=sys.stderr)
|
||||
print("RBAC: ENABLED - permissions enforced via groups", file=sys.stderr)
|
||||
else:
|
||||
print("OAuth: disabled", file=sys.stderr)
|
||||
print("OAuth: disabled (single-user mode)", file=sys.stderr)
|
||||
print("─" * 40, file=sys.stderr)
|
||||
|
||||
# Create and run server
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user