feat: implement RBAC middleware for OAuth group-based permissions
Add RBACMiddleware that integrates with FastMCP's middleware system to enforce role-based access control on all tool calls: - Intercepts every tool call via on_call_tool() hook - Extracts user groups from OAuth token claims - Checks permissions using existing permissions.py mappings - Logs all tool invocations with user identity via audit.py - Denies access with clear PermissionDeniedError when unauthorized Permission levels (from permissions.py): - READ_ONLY: view operations (vsphere-readers) - POWER_OPS: power/snapshot ops (vsphere-operators) - VM_LIFECYCLE: create/delete VMs (vsphere-admins) - HOST_ADMIN: ESXi host management (vsphere-host-admins) - FULL_ADMIN: guest ops, services (vsphere-super-admins) Middleware only enabled when OAuth is active (OAUTH_ENABLED=true). STDIO mode continues to work without permission checking.
This commit is contained in:
parent
0e29fea857
commit
6159098963
@ -3,12 +3,18 @@
|
|||||||
Provides decorators and hooks for wrapping tool execution with:
|
Provides decorators and hooks for wrapping tool execution with:
|
||||||
- OAuth permission validation
|
- OAuth permission validation
|
||||||
- Audit logging with user identity
|
- Audit logging with user identity
|
||||||
|
- FastMCP middleware integration for RBAC
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
import mcp.types as mt
|
||||||
|
from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext
|
||||||
|
from fastmcp.tools.tool import ToolResult
|
||||||
|
|
||||||
from mcvsphere.audit import (
|
from mcvsphere.audit import (
|
||||||
audit_log,
|
audit_log,
|
||||||
@ -24,6 +30,11 @@ from mcvsphere.permissions import (
|
|||||||
get_required_permission,
|
get_required_permission,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from fastmcp.server.context import Context
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def with_permission_check(tool_name: str) -> Callable:
|
def with_permission_check(tool_name: str) -> Callable:
|
||||||
"""Decorator factory for permission checking and audit logging.
|
"""Decorator factory for permission checking and audit logging.
|
||||||
@ -179,3 +190,151 @@ def get_permission_summary() -> dict[str, list[str]]:
|
|||||||
summary[level].sort()
|
summary[level].sort()
|
||||||
|
|
||||||
return summary
|
return summary
|
||||||
|
|
||||||
|
|
||||||
|
class RBACMiddleware(Middleware):
|
||||||
|
"""FastMCP middleware for Role-Based Access Control.
|
||||||
|
|
||||||
|
Integrates with FastMCP's middleware system to enforce permissions
|
||||||
|
on every tool call based on OAuth group memberships.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
mcp = FastMCP("my-server", auth=oauth_provider)
|
||||||
|
mcp.add_middleware(RBACMiddleware())
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _extract_user_from_context(
|
||||||
|
self, fastmcp_ctx: "Context | None"
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Extract user claims from FastMCP context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fastmcp_ctx: FastMCP Context object from middleware context.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User claims dict from OAuth token, or None if not authenticated.
|
||||||
|
"""
|
||||||
|
if fastmcp_ctx is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# FastMCP stores access token in request_context
|
||||||
|
if hasattr(fastmcp_ctx, "request_context") and fastmcp_ctx.request_context:
|
||||||
|
token = getattr(fastmcp_ctx.request_context, "access_token", None)
|
||||||
|
if token and hasattr(token, "claims"):
|
||||||
|
return token.claims
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("Failed to extract user from context: %s", e)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_username(self, claims: dict[str, Any] | None) -> str:
|
||||||
|
"""Extract username from OAuth claims.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
claims: OAuth token claims dict.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Username string, or 'anonymous' if no claims.
|
||||||
|
"""
|
||||||
|
if not claims:
|
||||||
|
return "anonymous"
|
||||||
|
|
||||||
|
for claim in ("preferred_username", "email", "sub"):
|
||||||
|
if value := claims.get(claim):
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
def _get_groups(self, claims: dict[str, Any] | None) -> list[str]:
|
||||||
|
"""Extract groups from OAuth claims.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
claims: OAuth token claims dict.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of group names, or empty list if no claims.
|
||||||
|
"""
|
||||||
|
if not claims:
|
||||||
|
return []
|
||||||
|
|
||||||
|
groups = claims.get("groups", [])
|
||||||
|
if isinstance(groups, list):
|
||||||
|
return groups
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def on_call_tool(
|
||||||
|
self,
|
||||||
|
context: MiddlewareContext[mt.CallToolRequestParams],
|
||||||
|
call_next: CallNext[mt.CallToolRequestParams, ToolResult],
|
||||||
|
) -> ToolResult:
|
||||||
|
"""Intercept tool calls to enforce RBAC permissions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: Middleware context containing tool call params.
|
||||||
|
call_next: Next handler in the middleware chain.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tool result if permitted.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
PermissionDeniedError: If user lacks required permission.
|
||||||
|
"""
|
||||||
|
# Extract tool name and arguments from the request
|
||||||
|
tool_name = context.message.name
|
||||||
|
tool_args = context.message.arguments or {}
|
||||||
|
|
||||||
|
# Get user info from OAuth context
|
||||||
|
claims = self._extract_user_from_context(context.fastmcp_context)
|
||||||
|
username = self._get_username(claims)
|
||||||
|
groups = self._get_groups(claims)
|
||||||
|
|
||||||
|
# Set up audit context for this request
|
||||||
|
set_current_user(claims)
|
||||||
|
|
||||||
|
# Check permission
|
||||||
|
if not check_permission(tool_name, groups):
|
||||||
|
required = get_required_permission(tool_name)
|
||||||
|
logger.warning(
|
||||||
|
"Permission denied: user=%s groups=%s tool=%s required=%s",
|
||||||
|
username,
|
||||||
|
groups,
|
||||||
|
tool_name,
|
||||||
|
required.value,
|
||||||
|
)
|
||||||
|
audit_permission_denied(tool_name, tool_args, required.value)
|
||||||
|
raise PermissionDeniedError(username, tool_name, required)
|
||||||
|
|
||||||
|
# Permission granted - execute tool with timing
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
try:
|
||||||
|
result = await call_next(context)
|
||||||
|
duration_ms = (time.perf_counter() - start_time) * 1000
|
||||||
|
|
||||||
|
# Audit successful execution
|
||||||
|
# ToolResult can be complex, just log that it succeeded
|
||||||
|
audit_log(tool_name, tool_args, result="success", duration_ms=duration_ms)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"Tool executed: user=%s tool=%s duration=%.2fms",
|
||||||
|
username,
|
||||||
|
tool_name,
|
||||||
|
duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except PermissionDeniedError:
|
||||||
|
# Re-raise without additional logging
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
duration_ms = (time.perf_counter() - start_time) * 1000
|
||||||
|
audit_log(tool_name, tool_args, error=str(e), duration_ms=duration_ms)
|
||||||
|
logger.error(
|
||||||
|
"Tool failed: user=%s tool=%s error=%s duration=%.2fms",
|
||||||
|
username,
|
||||||
|
tool_name,
|
||||||
|
str(e),
|
||||||
|
duration_ms,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from fastmcp import FastMCP
|
|||||||
from mcvsphere.auth import create_auth_provider
|
from mcvsphere.auth import create_auth_provider
|
||||||
from mcvsphere.config import Settings, get_settings
|
from mcvsphere.config import Settings, get_settings
|
||||||
from mcvsphere.connection import VMwareConnection
|
from mcvsphere.connection import VMwareConnection
|
||||||
|
from mcvsphere.middleware import RBACMiddleware
|
||||||
from mcvsphere.mixins import (
|
from mcvsphere.mixins import (
|
||||||
ConsoleMixin,
|
ConsoleMixin,
|
||||||
DiskManagementMixin,
|
DiskManagementMixin,
|
||||||
@ -68,6 +69,11 @@ def create_server(settings: Settings | None = None) -> FastMCP:
|
|||||||
auth=auth,
|
auth=auth,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Add RBAC middleware when OAuth is enabled
|
||||||
|
if settings.oauth_enabled:
|
||||||
|
mcp.add_middleware(RBACMiddleware())
|
||||||
|
logger.info("RBAC middleware enabled - permissions enforced via OAuth groups")
|
||||||
|
|
||||||
# Create shared VMware connection
|
# Create shared VMware connection
|
||||||
logger.info("Connecting to VMware vCenter/ESXi...")
|
logger.info("Connecting to VMware vCenter/ESXi...")
|
||||||
conn = VMwareConnection(settings)
|
conn = VMwareConnection(settings)
|
||||||
@ -137,8 +143,9 @@ def run_server(config_path: Path | None = None) -> None:
|
|||||||
)
|
)
|
||||||
if settings.oauth_enabled:
|
if settings.oauth_enabled:
|
||||||
print(f"OAuth: ENABLED via {settings.oauth_issuer_url}", file=sys.stderr)
|
print(f"OAuth: ENABLED via {settings.oauth_issuer_url}", file=sys.stderr)
|
||||||
|
print("RBAC: ENABLED - permissions enforced via groups", file=sys.stderr)
|
||||||
else:
|
else:
|
||||||
print("OAuth: disabled", file=sys.stderr)
|
print("OAuth: disabled (single-user mode)", file=sys.stderr)
|
||||||
print("─" * 40, file=sys.stderr)
|
print("─" * 40, file=sys.stderr)
|
||||||
|
|
||||||
# Create and run server
|
# Create and run server
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user