"""Cursor-based pagination system for large MCP responses. Provides efficient pagination with grep filtering, session isolation, and TTL-based cursor expiration. """ import hashlib import json import re import time from collections import OrderedDict from dataclasses import dataclass, field from threading import Lock from typing import Any, Dict, List, Optional, Tuple from ..config import get_config from .filtering import estimate_and_guard, project_fields # ReDoS Protection Configuration MAX_GREP_PATTERN_LENGTH = 500 MAX_GREP_REPETITION_OPS = 15 MAX_GREP_RECURSION_DEPTH = 10 # Token estimation (roughly 4 chars per token) TOKEN_ESTIMATION_RATIO = 4.0 def compile_safe_pattern(pattern: str, flags: int = 0) -> re.Pattern: """Compile regex pattern with ReDoS protection. Validates pattern to prevent catastrophic backtracking attacks. Args: pattern: Regex pattern string flags: Regex compilation flags Returns: Compiled regex pattern Raises: ValueError: If pattern fails safety validation """ if not pattern: raise ValueError("Empty pattern") if len(pattern) > MAX_GREP_PATTERN_LENGTH: raise ValueError( f"Pattern too long ({len(pattern)} chars, max {MAX_GREP_PATTERN_LENGTH}). " "Consider using a simpler pattern." ) # Count repetition operators repetition_ops = pattern.count("*") + pattern.count("+") + pattern.count("?") repetition_ops += len(re.findall(r"\{[0-9,]+\}", pattern)) if repetition_ops > MAX_GREP_REPETITION_OPS: raise ValueError( f"Pattern has too many repetition operators ({repetition_ops}, " f"max {MAX_GREP_REPETITION_OPS}). Consider simplifying." ) # Check for dangerous nested quantifiers dangerous_patterns = [ r"\([^)]*[*+][^)]*\)[*+]", # (a+)+ or (a*)* r"\([^)]*[*+][^)]*\)\{", # (a+){n,m} ] for dangerous in dangerous_patterns: if re.search(dangerous, pattern): raise ValueError( "Pattern contains nested quantifiers which could cause " "exponential backtracking. Consider simplifying." ) try: return re.compile(pattern, flags) except re.error as e: raise ValueError(f"Invalid regex pattern: {e}") @dataclass class CursorState: """Represents the state of a paginated query with session isolation.""" cursor_id: str session_id: str tool_name: str query_hash: str data: List[Any] total_count: int filtered_count: int current_offset: int = 0 page_size: int = 50 grep_pattern: Optional[str] = None grep_flags: int = 0 created_at: float = field(default_factory=time.time) last_accessed: float = field(default_factory=time.time) @property def is_expired(self) -> bool: config = get_config() return time.time() - self.last_accessed > config.cursor_ttl_seconds @property def has_more(self) -> bool: return self.current_offset + self.page_size < self.filtered_count @property def current_page(self) -> int: return (self.current_offset // self.page_size) + 1 @property def total_pages(self) -> int: return max(1, (self.filtered_count + self.page_size - 1) // self.page_size) @property def ttl_remaining(self) -> int: config = get_config() return max(0, int(config.cursor_ttl_seconds - (time.time() - self.last_accessed))) def verify_session(self, session_id: str) -> bool: """Verify cursor belongs to requesting session.""" return self.session_id == session_id class CursorManager: """Thread-safe cursor manager with TTL-based expiration and session isolation.""" def __init__(self): self._cursors: OrderedDict[str, CursorState] = OrderedDict() self._session_cursors: Dict[str, set] = {} self._lock = Lock() def _generate_cursor_id(self, query_hash: str, session_id: str) -> str: """Generate a unique cursor ID.""" unique = f"{session_id}-{query_hash}-{time.time()}-{id(self)}" return hashlib.sha256(unique.encode()).hexdigest()[:16] def _cleanup_expired(self) -> None: """Remove expired cursors (call while holding lock).""" config = get_config() expired = [cid for cid, state in self._cursors.items() if state.is_expired] for cid in expired: state = self._cursors[cid] if state.session_id in self._session_cursors: self._session_cursors[state.session_id].discard(cid) del self._cursors[cid] # LRU eviction while len(self._cursors) > config.max_cursors_per_session: oldest_id, oldest_state = self._cursors.popitem(last=False) if oldest_state.session_id in self._session_cursors: self._session_cursors[oldest_state.session_id].discard(oldest_id) def create_cursor( self, data: List[Any], query_params: Dict[str, Any], tool_name: str = "unknown", session_id: str = "default", grep_pattern: Optional[str] = None, grep_flags: int = 0, page_size: int = 50, ) -> Tuple[str, CursorState]: """Create a new cursor for paginated results. Args: data: The full result set to paginate query_params: Original query parameters (for hashing) tool_name: Name of tool creating cursor session_id: Session identifier for isolation grep_pattern: Optional regex pattern to filter results grep_flags: Regex flags page_size: Items per page Returns: Tuple of (cursor_id, cursor_state) """ config = get_config() # Apply grep filtering filtered_data = data if grep_pattern: pattern = compile_safe_pattern(grep_pattern, grep_flags) filtered_data = [ item for item in data if self._matches_grep(item, pattern) ] # Create query hash query_hash = hashlib.md5( json.dumps(query_params, sort_keys=True, default=str).encode() ).hexdigest()[:12] with self._lock: self._cleanup_expired() cursor_id = self._generate_cursor_id(query_hash, session_id) state = CursorState( cursor_id=cursor_id, session_id=session_id, tool_name=tool_name, query_hash=query_hash, data=filtered_data, total_count=len(data), filtered_count=len(filtered_data), page_size=min(page_size, config.max_page_size), grep_pattern=grep_pattern, grep_flags=grep_flags, ) self._cursors[cursor_id] = state if session_id not in self._session_cursors: self._session_cursors[session_id] = set() self._session_cursors[session_id].add(cursor_id) return cursor_id, state def get_cursor( self, cursor_id: str, session_id: Optional[str] = None ) -> Optional[CursorState]: """Retrieve a cursor by ID, optionally validating session.""" with self._lock: self._cleanup_expired() if cursor_id not in self._cursors: return None state = self._cursors[cursor_id] if state.is_expired: del self._cursors[cursor_id] if state.session_id in self._session_cursors: self._session_cursors[state.session_id].discard(cursor_id) return None if session_id and not state.verify_session(session_id): return None state.last_accessed = time.time() self._cursors.move_to_end(cursor_id) return state def advance_cursor( self, cursor_id: str, session_id: Optional[str] = None ) -> Optional[CursorState]: """Advance cursor to next page.""" with self._lock: state = self._cursors.get(cursor_id) if not state or state.is_expired: return None if session_id and not state.verify_session(session_id): return None state.current_offset += state.page_size state.last_accessed = time.time() self._cursors.move_to_end(cursor_id) return state def delete_cursor( self, cursor_id: str, session_id: Optional[str] = None ) -> bool: """Explicitly delete a cursor.""" with self._lock: if cursor_id not in self._cursors: return False state = self._cursors[cursor_id] if session_id and not state.verify_session(session_id): return False if state.session_id in self._session_cursors: self._session_cursors[state.session_id].discard(cursor_id) del self._cursors[cursor_id] return True def delete_session_cursors(self, session_id: str) -> int: """Delete all cursors for a session.""" with self._lock: if session_id not in self._session_cursors: return 0 cursor_ids = list(self._session_cursors[session_id]) count = 0 for cid in cursor_ids: if cid in self._cursors: del self._cursors[cid] count += 1 del self._session_cursors[session_id] return count def get_page(self, state: CursorState) -> List[Any]: """Get current page of data from cursor state.""" start = state.current_offset end = start + state.page_size return state.data[start:end] def _matches_grep( self, item: Any, pattern: re.Pattern, depth: int = 0 ) -> bool: """Check if an item matches the grep pattern. Searches through string representations of dict values, list items, or the item itself. """ if depth > MAX_GREP_RECURSION_DEPTH: return False if isinstance(item, dict): for value in item.values(): if isinstance(value, str) and pattern.search(value): return True elif isinstance(value, (int, float)): if pattern.search(str(value)): return True elif isinstance(value, dict): if self._matches_grep(value, pattern, depth + 1): return True elif isinstance(value, (list, tuple)): if self._matches_grep(value, pattern, depth + 1): return True return False elif isinstance(item, (list, tuple)): return any(self._matches_grep(i, pattern, depth + 1) for i in item) elif isinstance(item, str): return bool(pattern.search(item)) else: return bool(pattern.search(str(item))) def list_cursors(self, session_id: Optional[str] = None) -> List[Dict[str, Any]]: """List active cursors, optionally filtered by session.""" with self._lock: self._cleanup_expired() return [ { "cursor_id": cid, "session_id": state.session_id, "tool_name": state.tool_name, "total_count": state.total_count, "filtered_count": state.filtered_count, "current_page": state.current_page, "total_pages": state.total_pages, "current_offset": state.current_offset, "page_size": state.page_size, "has_more": state.has_more, "grep_pattern": state.grep_pattern, "age_seconds": int(time.time() - state.created_at), "ttl_remaining": state.ttl_remaining, } for cid, state in self._cursors.items() if session_id is None or state.session_id == session_id ] def get_stats(self) -> Dict[str, Any]: """Get cursor manager statistics.""" config = get_config() with self._lock: self._cleanup_expired() return { "total_cursors": len(self._cursors), "total_sessions": len(self._session_cursors), "max_cache_size": config.max_cursors_per_session, "ttl_seconds": config.cursor_ttl_seconds, "cursors_per_session": { sid: len(cids) for sid, cids in self._session_cursors.items() }, } # Global cursor manager instance _cursor_manager: Optional[CursorManager] = None def get_cursor_manager() -> CursorManager: """Get the global cursor manager instance.""" global _cursor_manager if _cursor_manager is None: _cursor_manager = CursorManager() return _cursor_manager def estimate_tokens(data: List[Any]) -> int: """Estimate token count for a list of items.""" text = json.dumps(data, default=str) return int(len(text) / TOKEN_ESTIMATION_RATIO) def paginate_response( data: List[Any], query_params: Dict[str, Any], tool_name: str = "unknown", session_id: str = "default", page_size: int = 50, grep: Optional[str] = None, grep_ignorecase: bool = True, return_all: bool = False, fields: Optional[List[str]] = None, ) -> Dict[str, Any]: """Create a paginated response with optional grep filtering and field projection. Args: data: Full result list to paginate query_params: Original query parameters (for cursor creation) tool_name: Name of the tool creating this response session_id: Session identifier for cursor isolation page_size: Items per page (default: 50, max: 500) grep: Optional regex pattern to filter results grep_ignorecase: Case-insensitive grep (default: True) return_all: Bypass pagination and return all results (with budget guard) fields: Optional list of field names to project (jq-style) Returns: dict with pagination metadata and results """ config = get_config() cursor_manager = get_cursor_manager() grep_flags = re.IGNORECASE if grep_ignorecase else 0 # Handle return_all bypass if return_all: filtered_data = data if grep: try: pattern = compile_safe_pattern(grep, grep_flags) filtered_data = [ item for item in data if cursor_manager._matches_grep(item, pattern) ] except ValueError as e: return { "success": False, "error": {"code": "INVALID_GREP_PATTERN", "message": str(e)}, "timestamp": int(time.time() * 1000), } # Apply field projection before size estimation if fields: filtered_data = project_fields(filtered_data, fields) # Check token budget — return guard if exceeded guard = estimate_and_guard( data=filtered_data, tool_name=tool_name, query_hints=query_params, ) if guard is not None: return guard estimated_tokens = estimate_tokens(filtered_data) warning = None if estimated_tokens > 50000: warning = f"EXTREMELY LARGE response (~{estimated_tokens:,} tokens)" elif estimated_tokens > 20000: warning = f"VERY LARGE response (~{estimated_tokens:,} tokens)" elif estimated_tokens > config.large_response_threshold: warning = f"Large response (~{estimated_tokens:,} tokens)" return { "success": True, "result": filtered_data, "pagination": { "bypassed": True, "total_count": len(data), "filtered_count": len(filtered_data), "grep_pattern": grep, "fields_projected": fields, "estimated_tokens": estimated_tokens, "warning": warning, }, "timestamp": int(time.time() * 1000), } # Normal pagination flow — apply field projection before cursoring paginated_data = project_fields(data, fields) if fields else data try: cursor_id, state = cursor_manager.create_cursor( data=paginated_data, query_params=query_params, tool_name=tool_name, session_id=session_id, grep_pattern=grep, grep_flags=grep_flags, page_size=min(page_size, config.max_page_size), ) except ValueError as e: return { "success": False, "error": {"code": "INVALID_GREP_PATTERN", "message": str(e)}, "timestamp": int(time.time() * 1000), } current_page = cursor_manager.get_page(state) response_cursor = cursor_id if state.has_more else None response = { "success": True, "result": current_page, "pagination": { "cursor_id": response_cursor, "session_id": session_id, "total_count": state.total_count, "filtered_count": state.filtered_count, "page_size": state.page_size, "current_page": state.current_page, "total_pages": state.total_pages, "has_more": state.has_more, "grep_pattern": grep, "items_returned": len(current_page), }, "timestamp": int(time.time() * 1000), } # Add LLM-friendly continuation message if state.has_more: remaining = state.filtered_count - (state.current_page * state.page_size) response["_message"] = ( f"Showing {len(current_page)} of {state.filtered_count} items " f"(page {state.current_page}/{state.total_pages}). " f"To get the next {min(state.page_size, remaining)} items, call: " f"cursor_next(cursor_id='{cursor_id}')" ) else: response["_message"] = ( f"Complete: {len(current_page)} items returned (all results)" ) return response