diff --git a/docs/taskmaster/PLAN.md b/docs/taskmaster/PLAN.md new file mode 100644 index 0000000..95dba8e --- /dev/null +++ b/docs/taskmaster/PLAN.md @@ -0,0 +1,221 @@ +# Taskmaster Execution Plan: mcilspy Code Review Fixes + +## Overview + +33 issues identified in hater-hat code review, organized into 4 parallel workstreams. + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ MERGE ORDER (Sequential) │ +│ security → architecture → performance → testing │ +└─────────────────────────────────────────────────────────────────┘ + │ │ │ │ + ▼ ▼ ▼ ▼ + ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ + │SECURITY │ │ ARCH │ │ PERF │ │ TESTING │ + │ 4 issues│ │ 8 issues│ │ 8 issues│ │ 7 issues│ + │ P1-CRIT │ │ P2-HIGH │ │ P3-MED │ │ P4-LOW │ + └─────────┘ └─────────┘ └─────────┘ └─────────┘ +``` + +--- + +## Domain 1: SECURITY (Priority 1 - Critical) + +**Branch**: `fix/security` +**Estimated Effort**: 2-3 hours +**Blocks**: All other domains should wait for security review patterns + +### Issues to Fix + +| ID | Issue | File:Line | Fix | +|----|-------|-----------|-----| +| S1 | Path traversal - no validation | `server.py:*`, `ilspy_wrapper.py:132` | Add `_validate_assembly_path()` helper that checks: exists, is file, has .dll/.exe extension, resolves to absolute path, optionally check for PE signature | +| S2 | Temp directory race condition | `ilspy_wrapper.py:150-153` | Replace `tempfile.mkdtemp()` with `tempfile.TemporaryDirectory()` context manager | +| S3 | Unbounded subprocess output | `ilspy_wrapper.py:104-107` | Add `MAX_OUTPUT_BYTES = 50_000_000` constant, truncate stdout/stderr if exceeded | +| S4 | No file size limit before loading | `metadata_reader.py:113-115` | Add `MAX_ASSEMBLY_SIZE_MB = 500` check before `dnfile.dnPE()` | + +### Acceptance Criteria +- [ ] All user-provided paths validated before use +- [ ] No temp file leaks possible +- [ ] Memory exhaustion attacks mitigated +- [ ] Unit tests for validation functions + +--- + +## 
Domain 2: ARCHITECTURE (Priority 2 - High) + +**Branch**: `fix/architecture` +**Estimated Effort**: 3-4 hours +**Depends On**: Security (for validation patterns) + +### Issues to Fix + +| ID | Issue | File:Line | Fix | +|----|-------|-----------|-----| +| A1 | Duplicated PATH discovery | `server.py:99-123`, `ilspy_wrapper.py:41-75` | Extract to `src/mcilspy/utils.py` → `find_ilspycmd_path()` | +| A2 | Mixed dataclass/Pydantic models | `metadata_reader.py` vs `models.py` | Convert metadata_reader dataclasses to Pydantic models in `models.py` | +| A3 | Fragile lifespan context access | `server.py:51-80` | Simplify to module-level `_wrapper: ILSpyWrapper | None = None` with proper locking | +| A4 | Stateless wrapper pretending to be stateful | `ilspy_wrapper.py` | Keep as-is but document why (caches ilspycmd_path lookup) OR convert to module functions | +| A5 | Inconsistent async/sync | `metadata_reader.py` | Document as "CPU-bound sync" in docstring, add note about thread pool for heavy loads | +| A6 | Magic numbers scattered | Throughout | Create `src/mcilspy/constants.py` with all limits/timeouts | +| A7 | Repeated regex compilation | `server.py` (6 places) | Add `_compile_search_pattern(pattern, case_sensitive, use_regex)` helper | +| A8 | Language version validation missing | `server.py:578` | Add try/except with helpful error listing valid versions | + +### Acceptance Criteria +- [ ] Single source of truth for PATH discovery +- [ ] Consistent data model layer (all Pydantic) +- [ ] All magic numbers in constants.py +- [ ] Helper functions reduce code duplication + +--- + +## Domain 3: PERFORMANCE (Priority 3 - Medium) + +**Branch**: `fix/performance` +**Estimated Effort**: 4-5 hours +**Depends On**: Architecture (for new constants/utils) + +### Issues to Fix + +| ID | Issue | File:Line | Fix | +|----|-------|-----------|-----| +| P1 | search_strings decompiles entire assembly | `server.py:948-954` | Use dnfile's `pe.net.user_strings` to search string heap 
directly - 100x faster | +| P2 | No result pagination | All list_* tools | Add `max_results: int = 1000` and `offset: int = 0` params, return `has_more` flag | +| P3 | List conversion instead of generators | `metadata_reader.py:289` | Use `enumerate()` directly on iterator where possible | +| P4 | No caching of decompilation | `ilspy_wrapper.py` | Add optional LRU cache keyed by (path, mtime, type_name) | +| P5 | Silent failures in platform detection | `server.py:195-230` | Log warnings for permission errors, return reason in result | +| P6 | Generic exception catches | Throughout | Replace with specific exceptions, preserve stack traces | +| P7 | MetadataReader __exit__ type mismatch | `metadata_reader.py:595-597` | Fix return type, remove unnecessary `return False` | +| P8 | No assembly validation | `ilspy_wrapper.py` | Add PE signature check (MZ header + PE\0\0) before subprocess | + +### Acceptance Criteria +- [ ] search_strings uses string heap (benchmark: 10x improvement) +- [ ] All list tools support pagination +- [ ] No silent swallowed exceptions +- [ ] PE validation prevents wasted subprocess calls + +--- + +## Domain 4: TESTING (Priority 4 - Enhancement) + +**Branch**: `fix/testing` +**Estimated Effort**: 5-6 hours +**Depends On**: All other domains (tests should cover new code) + +### Issues to Fix + +| ID | Issue | Fix | +|----|-------|-----| +| T1 | No integration tests | Create `tests/integration/` with real assembly tests. 
Use a small test .dll checked into repo | +| T2 | No MCP tool tests | Add `tests/test_server_tools.py` testing each `@mcp.tool()` function | +| T3 | No error path tests | Add tests for: regex compilation failure, ilspycmd not found, ctx.info() failure | +| T4 | No concurrency tests | Add `tests/test_concurrency.py` with parallel tool invocations | +| T5 | Missing docstring validation | Add test that all public functions have docstrings (using AST) | +| T6 | No cancel/progress tests | Test timeout behavior, verify progress reporting works | +| T7 | Add test .NET assembly | Create minimal C# project, compile to .dll, check into `tests/fixtures/` | + +### Acceptance Criteria +- [ ] Integration tests with real ilspycmd calls +- [ ] 80%+ code coverage (currently ~40% estimated) +- [ ] All error paths tested +- [ ] CI runs integration tests + +--- + +## Deferred (Won't Fix This Sprint) + +These issues are valid but lower priority: + +| ID | Issue | Reason to Defer | +|----|-------|-----------------| +| D1 | No cancel/abort for decompilation | Requires MCP protocol support for cancellation | +| D2 | No progress reporting | Needs ilspycmd changes or parsing stdout in real-time | +| D3 | No resource extraction | Feature request, not bug - add to backlog | +| D4 | install_ilspy sudo handling | Edge case - document limitation instead | +| D5 | No dry-run for installation | Nice-to-have, not critical | +| D6 | Error messages expose paths | Low risk for MCP (local tool) | + +--- + +## Execution Commands + +### Setup Worktrees (Run Once) + +```bash +# Create worktrees for parallel development (-b creates the new branch from HEAD) +git worktree add -b fix/security ../mcilspy-security +git worktree add -b fix/architecture ../mcilspy-arch +git worktree add -b fix/performance ../mcilspy-perf +git worktree add -b fix/testing ../mcilspy-testing +``` + +### Task Master Commands + +```bash +# Security Task Master +cd ../mcilspy-security +claude -p "You are the 
SECURITY Task Master. Read docs/taskmaster/PLAN.md and implement all S1-S4 issues. Update status.json when done." + +# Architecture Task Master (can start in parallel) +cd ../mcilspy-arch +claude -p "You are the ARCHITECTURE Task Master. Read docs/taskmaster/PLAN.md and implement all A1-A8 issues. Check status.json for security patterns first." + +# Performance Task Master +cd ../mcilspy-perf +claude -p "You are the PERFORMANCE Task Master. Read docs/taskmaster/PLAN.md and implement all P1-P8 issues." + +# Testing Task Master (runs last) +cd ../mcilspy-testing +claude -p "You are the TESTING Task Master. Read docs/taskmaster/PLAN.md and implement all T1-T7 issues. Requires other domains complete first." +``` + +### Merge Protocol + +```bash +# After all complete, merge in order: +git checkout main +git merge --no-ff fix/security -m "security: path validation, temp cleanup, output limits" +git merge --no-ff fix/architecture -m "refactor: consolidate utils, constants, models" +git merge --no-ff fix/performance -m "perf: string heap search, pagination, caching" +git merge --no-ff fix/testing -m "test: integration tests, tool coverage, fixtures" + +# Cleanup +git worktree remove ../mcilspy-security +git worktree remove ../mcilspy-arch +git worktree remove ../mcilspy-perf +git worktree remove ../mcilspy-testing +``` + +--- + +## Priority Matrix + +``` + IMPACT + Low Medium High + ┌────────┬─────────┬─────────┐ + High │ A5 │ P2,P3 │ S1-S4 │ ← Fix First + EFFORT │ A8 │ A6,A7 │ A1-A4 │ + Medium │ T5 │ P5-P8 │ P1,P4 │ + │ D4 │ T3,T4 │ T1,T2 │ + Low │ D5,D6 │ T6 │ T7 │ ← Quick Wins + └────────┴─────────┴─────────┘ +``` + +**Quick Wins** (do first within each domain): +- S2: Temp directory fix (5 min) +- A6: Constants file (15 min) +- P7: __exit__ fix (2 min) +- T7: Add test assembly (30 min) + +--- + +## Definition of Done + +- [ ] All issues in domain addressed +- [ ] Tests pass: `uv run pytest` +- [ ] Lint passes: `uv run ruff check src/` +- [ ] Types pass: `uv run ruff 
check src/ --select=ANN` +- [ ] status.json updated to "ready" +- [ ] PR draft created (not merged) diff --git a/docs/taskmaster/status.json b/docs/taskmaster/status.json new file mode 100644 index 0000000..823ff6f --- /dev/null +++ b/docs/taskmaster/status.json @@ -0,0 +1,11 @@ +{ + "project": "mcilspy-code-review-fixes", + "created": "2025-02-08T00:00:00Z", + "domains": { + "security": { "status": "ready", "branch": "fix/security", "priority": 1 }, + "architecture": { "status": "pending", "branch": "fix/architecture", "priority": 2 }, + "performance": { "status": "pending", "branch": "fix/performance", "priority": 3 }, + "testing": { "status": "pending", "branch": "fix/testing", "priority": 4 } + }, + "merge_order": ["security", "architecture", "performance", "testing"] +} diff --git a/src/mcilspy/ilspy_wrapper.py b/src/mcilspy/ilspy_wrapper.py index a353357..6f02a82 100644 --- a/src/mcilspy/ilspy_wrapper.py +++ b/src/mcilspy/ilspy_wrapper.py @@ -22,6 +22,10 @@ from .models import ( logger = logging.getLogger(__name__) +# Maximum bytes to read from subprocess stdout/stderr to prevent memory exhaustion +# from malicious or corrupted assemblies that produce huge output +MAX_OUTPUT_BYTES = 50_000_000 # 50 MB + class ILSpyWrapper: """Wrapper class for ILSpy command line tool.""" @@ -85,6 +89,10 @@ class ILSpyWrapper: Returns: Tuple of (return_code, stdout, stderr) + + Note: + Output is truncated to MAX_OUTPUT_BYTES to prevent memory exhaustion + from malicious or corrupted assemblies. """ cmd = [self.ilspycmd_path] + args logger.debug(f"Running command: {' '.join(cmd)}") @@ -111,9 +119,33 @@ class ILSpyWrapper: await process.wait() # Ensure process is cleaned up return -1, "", "Command timed out after 5 minutes. The assembly may be corrupted or too complex." 
+ # Truncate output if it exceeds the limit to prevent memory exhaustion + stdout_truncated = False + stderr_truncated = False + + if stdout_bytes and len(stdout_bytes) > MAX_OUTPUT_BYTES: + logger.warning( + f"stdout truncated from {len(stdout_bytes)} to {MAX_OUTPUT_BYTES} bytes" + ) + stdout_bytes = stdout_bytes[:MAX_OUTPUT_BYTES] + stdout_truncated = True + + if stderr_bytes and len(stderr_bytes) > MAX_OUTPUT_BYTES: + logger.warning( + f"stderr truncated from {len(stderr_bytes)} to {MAX_OUTPUT_BYTES} bytes" + ) + stderr_bytes = stderr_bytes[:MAX_OUTPUT_BYTES] + stderr_truncated = True + stdout = stdout_bytes.decode("utf-8", errors="replace") if stdout_bytes else "" stderr = stderr_bytes.decode("utf-8", errors="replace") if stderr_bytes else "" + # Add truncation warning to output + if stdout_truncated: + stdout += "\n\n[OUTPUT TRUNCATED - exceeded 50MB limit]" + if stderr_truncated: + stderr += "\n\n[OUTPUT TRUNCATED - exceeded 50MB limit]" + return process.returncode, stdout, stderr except Exception as e: @@ -136,6 +168,28 @@ class ILSpyWrapper: assembly_name=os.path.basename(request.assembly_path), ) + # Use TemporaryDirectory context manager for guaranteed cleanup (no race condition) + # when user doesn't specify an output directory + if request.output_dir: + # User specified output directory - use it directly + return await self._decompile_to_dir(request, request.output_dir) + else: + # Create a temporary directory with guaranteed cleanup + with tempfile.TemporaryDirectory() as temp_dir: + return await self._decompile_to_dir(request, temp_dir) + + async def _decompile_to_dir( + self, request: DecompileRequest, output_dir: str + ) -> DecompileResponse: + """Internal helper to perform decompilation to a specific directory. 
+ + Args: + request: Decompilation request + output_dir: Directory to write output to + + Returns: + Decompilation response + """ args = [request.assembly_path] # Add language version @@ -145,13 +199,6 @@ class ILSpyWrapper: if request.type_name: args.extend(["-t", request.type_name]) - # Add output directory if specified - temp_dir = None - output_dir = request.output_dir - if not output_dir: - temp_dir = tempfile.mkdtemp() - output_dir = temp_dir - args.extend(["-o", output_dir]) # Add project creation flag @@ -190,7 +237,7 @@ class ILSpyWrapper: assembly_name = os.path.splitext(os.path.basename(request.assembly_path))[0] if return_code == 0: - # If no output directory was specified, return stdout as source code + # If no output directory was specified by user, return stdout as source code source_code = None output_path = None @@ -237,10 +284,6 @@ class ILSpyWrapper: assembly_name=os.path.basename(request.assembly_path), type_name=request.type_name, ) - finally: - # Clean up temporary directory - if temp_dir and os.path.exists(temp_dir): - shutil.rmtree(temp_dir, ignore_errors=True) async def list_types(self, request: ListTypesRequest) -> ListTypesResponse: """List types in a .NET assembly. 
@@ -464,8 +507,8 @@ class ILSpyWrapper: # Try to extract more info by decompiling assembly attributes # Decompile with minimal output to get assembly-level attributes - temp_dir = tempfile.mkdtemp() - try: + # Use TemporaryDirectory context manager for guaranteed cleanup (no race condition) + with tempfile.TemporaryDirectory() as temp_dir: args = [ request.assembly_path, "-o", @@ -523,10 +566,6 @@ class ILSpyWrapper: if title_match: full_name = f"{title_match.group(1)}, Version={version}" - finally: - # Clean up temp directory - shutil.rmtree(temp_dir, ignore_errors=True) - return AssemblyInfo( name=name, version=version, diff --git a/src/mcilspy/metadata_reader.py b/src/mcilspy/metadata_reader.py index 2a2e1c6..c094bbf 100644 --- a/src/mcilspy/metadata_reader.py +++ b/src/mcilspy/metadata_reader.py @@ -18,6 +18,10 @@ from dnfile.mdtable import TypeDefRow logger = logging.getLogger(__name__) +# Maximum assembly file size to load (in megabytes) +# Prevents memory exhaustion from extremely large or malicious assemblies +MAX_ASSEMBLY_SIZE_MB = 500 + @dataclass class MethodInfo: @@ -101,6 +105,12 @@ class AssemblyMetadata: referenced_assemblies: list[str] = field(default_factory=list) +class AssemblySizeError(ValueError): + """Raised when an assembly exceeds the maximum allowed size.""" + + pass + + class MetadataReader: """Read .NET assembly metadata directly using dnfile.""" @@ -109,11 +119,25 @@ class MetadataReader: Args: assembly_path: Path to the .NET assembly file + + Raises: + FileNotFoundError: If the assembly file doesn't exist + AssemblySizeError: If the assembly exceeds MAX_ASSEMBLY_SIZE_MB """ self.assembly_path = Path(assembly_path) if not self.assembly_path.exists(): raise FileNotFoundError(f"Assembly not found: {assembly_path}") + # Check file size before loading to prevent memory exhaustion + file_size_bytes = self.assembly_path.stat().st_size + max_size_bytes = MAX_ASSEMBLY_SIZE_MB * 1024 * 1024 + if file_size_bytes > max_size_bytes: + size_mb = 
file_size_bytes / (1024 * 1024) + raise AssemblySizeError( + f"Assembly file size ({size_mb:.1f} MB) exceeds maximum allowed " + f"({MAX_ASSEMBLY_SIZE_MB} MB). This limit prevents memory exhaustion." + ) + self._pe: dnfile.dnPE | None = None self._type_cache: dict[int, TypeDefRow] = {} diff --git a/src/mcilspy/server.py b/src/mcilspy/server.py index 7807c6e..4a95e50 100644 --- a/src/mcilspy/server.py +++ b/src/mcilspy/server.py @@ -96,6 +96,57 @@ def _format_error(error: Exception, context: str = "") -> str: return f"**Error**: {error_msg}" +class AssemblyPathError(ValueError): + """Raised when an assembly path fails validation.""" + + pass + + +def _validate_assembly_path(assembly_path: str) -> str: + """Validate and normalize an assembly path for security. + + Performs the following checks: + 1. Path is not empty + 2. Resolves to an absolute path (prevents path traversal) + 3. File exists and is a regular file (not a directory or symlink to directory) + 4. Has a valid .NET assembly extension (.dll or .exe) + + Args: + assembly_path: User-provided path to a .NET assembly + + Returns: + Absolute, validated path to the assembly + + Raises: + AssemblyPathError: If the path fails any validation check + """ + if not assembly_path or not assembly_path.strip(): + raise AssemblyPathError("Assembly path cannot be empty") + + # Resolve to absolute path (handles .., symlinks, etc.) 
+ try: + resolved_path = os.path.realpath(os.path.expanduser(assembly_path.strip())) + except (OSError, ValueError) as e: + raise AssemblyPathError(f"Invalid path: {e}") from e + + # Check if path exists + if not os.path.exists(resolved_path): + raise AssemblyPathError(f"Assembly file not found: {resolved_path}") + + # Check if it's a regular file (not a directory) + if not os.path.isfile(resolved_path): + raise AssemblyPathError(f"Path is not a file: {resolved_path}") + + # Validate extension + _, ext = os.path.splitext(resolved_path) + if ext.lower() not in (".dll", ".exe"): + raise AssemblyPathError( + f"Invalid assembly extension '{ext}'. Expected .dll or .exe" + ) + + return resolved_path + + def _find_ilspycmd_path() -> str | None: """Find ilspycmd in PATH or common install locations.""" # Check PATH first @@ -562,8 +613,14 @@ async def decompile_assembly( show_il_sequence_points: Include debugging sequence points in IL output (implies show_il_code) nested_directories: Organize output files in namespace-based directory hierarchy """ + # Validate assembly path before any processing + try: + validated_path = _validate_assembly_path(assembly_path) + except AssemblyPathError as e: + return _format_error(e, "path validation") + if ctx: - await ctx.info(f"Starting decompilation of assembly: {assembly_path}") + await ctx.info(f"Starting decompilation of assembly: {validated_path}") try: wrapper = get_wrapper(ctx) @@ -572,7 +629,7 @@ async def decompile_assembly( from .models import DecompileRequest request = DecompileRequest( - assembly_path=assembly_path, + assembly_path=validated_path, output_dir=output_dir, type_name=type_name, language_version=LanguageVersion(language_version), @@ -633,8 +690,14 @@ async def list_types( - "enum" or "e" Example: ["class", "interface"] or ["c", "i"] """ + # Validate assembly path before any processing + try: + validated_path = _validate_assembly_path(assembly_path) + except AssemblyPathError as e: + return _format_error(e, "path 
validation") + if ctx: - await ctx.info(f"Listing types in assembly: {assembly_path}") + await ctx.info(f"Listing types in assembly: {validated_path}") try: wrapper = get_wrapper(ctx) @@ -654,12 +717,12 @@ async def list_types( from .models import ListTypesRequest - request = ListTypesRequest(assembly_path=assembly_path, entity_types=entity_type_enums) + request = ListTypesRequest(assembly_path=validated_path, entity_types=entity_type_enums) response = await wrapper.list_types(request) if response.success and response.types: - content = f"# Types in {assembly_path}\n\n" + content = f"# Types in {validated_path}\n\n" content += f"Found {response.total_count} types:\n\n" # Group by namespace @@ -713,8 +776,14 @@ async def generate_diagrammer( include_pattern: Regex to whitelist types (e.g., "MyApp\\\\.Services\\\\..+" for Services namespace) exclude_pattern: Regex to blacklist types (e.g., ".*Generated.*" to hide generated code) """ + # Validate assembly path before any processing + try: + validated_path = _validate_assembly_path(assembly_path) + except AssemblyPathError as e: + return _format_error(e, "path validation") + if ctx: - await ctx.info(f"Generating assembly diagram: {assembly_path}") + await ctx.info(f"Generating assembly diagram: {validated_path}") try: wrapper = get_wrapper(ctx) @@ -722,7 +791,7 @@ async def generate_diagrammer( from .models import GenerateDiagrammerRequest request = GenerateDiagrammerRequest( - assembly_path=assembly_path, + assembly_path=validated_path, output_dir=output_dir, include_pattern=include_pattern, exclude_pattern=exclude_pattern, @@ -756,15 +825,21 @@ async def get_assembly_info(assembly_path: str, ctx: Context | None = None) -> s Args: assembly_path: Full path to the .NET assembly file (.dll or .exe) """ + # Validate assembly path before any processing + try: + validated_path = _validate_assembly_path(assembly_path) + except AssemblyPathError as e: + return _format_error(e, "path validation") + if ctx: - await 
ctx.info(f"Getting assembly info: {assembly_path}") + await ctx.info(f"Getting assembly info: {validated_path}") try: wrapper = get_wrapper(ctx) from .models import AssemblyInfoRequest - request = AssemblyInfoRequest(assembly_path=assembly_path) + request = AssemblyInfoRequest(assembly_path=validated_path) info = await wrapper.get_assembly_info(request) @@ -817,8 +892,14 @@ async def search_types( case_sensitive: Whether pattern matching is case-sensitive (default: False) use_regex: Treat pattern as regular expression (default: False) """ + # Validate assembly path before any processing + try: + validated_path = _validate_assembly_path(assembly_path) + except AssemblyPathError as e: + return _format_error(e, "path validation") + if ctx: - await ctx.info(f"Searching for types matching '{pattern}' in: {assembly_path}") + await ctx.info(f"Searching for types matching '{pattern}' in: {validated_path}") try: wrapper = get_wrapper(ctx) @@ -838,7 +919,7 @@ async def search_types( from .models import ListTypesRequest - request = ListTypesRequest(assembly_path=assembly_path, entity_types=entity_type_enums) + request = ListTypesRequest(assembly_path=validated_path, entity_types=entity_type_enums) response = await wrapper.list_types(request) if not response.success: @@ -936,8 +1017,14 @@ async def search_strings( use_regex: Treat pattern as regular expression (default: False) max_results: Maximum number of matches to return (default: 100) """ + # Validate assembly path before any processing + try: + validated_path = _validate_assembly_path(assembly_path) + except AssemblyPathError as e: + return _format_error(e, "path validation") + if ctx: - await ctx.info(f"Searching for strings matching '{pattern}' in: {assembly_path}") + await ctx.info(f"Searching for strings matching '{pattern}' in: {validated_path}") try: wrapper = get_wrapper(ctx) @@ -946,7 +1033,7 @@ async def search_strings( from .models import DecompileRequest request = DecompileRequest( - 
assembly_path=assembly_path, + assembly_path=validated_path, show_il_code=True, # IL makes string literals explicit language_version=LanguageVersion.LATEST, ) @@ -1081,13 +1168,19 @@ async def search_methods( case_sensitive: Whether pattern matching is case-sensitive (default: False) use_regex: Treat pattern as regular expression (default: False) """ + # Validate assembly path before any processing + try: + validated_path = _validate_assembly_path(assembly_path) + except AssemblyPathError as e: + return _format_error(e, "path validation") + if ctx: - await ctx.info(f"Searching for methods matching '{pattern}' in: {assembly_path}") + await ctx.info(f"Searching for methods matching '{pattern}' in: {validated_path}") try: from .metadata_reader import MetadataReader - with MetadataReader(assembly_path) as reader: + with MetadataReader(validated_path) as reader: methods = reader.list_methods( type_filter=type_filter, namespace_filter=namespace_filter, @@ -1198,13 +1291,19 @@ async def search_fields( case_sensitive: Whether pattern matching is case-sensitive (default: False) use_regex: Treat pattern as regular expression (default: False) """ + # Validate assembly path before any processing + try: + validated_path = _validate_assembly_path(assembly_path) + except AssemblyPathError as e: + return _format_error(e, "path validation") + if ctx: - await ctx.info(f"Searching for fields matching '{pattern}' in: {assembly_path}") + await ctx.info(f"Searching for fields matching '{pattern}' in: {validated_path}") try: from .metadata_reader import MetadataReader - with MetadataReader(assembly_path) as reader: + with MetadataReader(validated_path) as reader: fields = reader.list_fields( type_filter=type_filter, namespace_filter=namespace_filter, @@ -1307,13 +1406,19 @@ async def search_properties( case_sensitive: Whether pattern matching is case-sensitive (default: False) use_regex: Treat pattern as regular expression (default: False) """ + # Validate assembly path before any 
processing + try: + validated_path = _validate_assembly_path(assembly_path) + except AssemblyPathError as e: + return _format_error(e, "path validation") + if ctx: - await ctx.info(f"Searching for properties matching '{pattern}' in: {assembly_path}") + await ctx.info(f"Searching for properties matching '{pattern}' in: {validated_path}") try: from .metadata_reader import MetadataReader - with MetadataReader(assembly_path) as reader: + with MetadataReader(validated_path) as reader: properties = reader.list_properties( type_filter=type_filter, namespace_filter=namespace_filter, @@ -1398,13 +1503,19 @@ async def list_events( type_filter: Only list events in types containing this string namespace_filter: Only list events in namespaces containing this string """ + # Validate assembly path before any processing + try: + validated_path = _validate_assembly_path(assembly_path) + except AssemblyPathError as e: + return _format_error(e, "path validation") + if ctx: - await ctx.info(f"Listing events in: {assembly_path}") + await ctx.info(f"Listing events in: {validated_path}") try: from .metadata_reader import MetadataReader - with MetadataReader(assembly_path) as reader: + with MetadataReader(validated_path) as reader: events = reader.list_events( type_filter=type_filter, namespace_filter=namespace_filter, @@ -1455,13 +1566,19 @@ async def list_resources( Args: assembly_path: Full path to the .NET assembly file (.dll or .exe) """ + # Validate assembly path before any processing + try: + validated_path = _validate_assembly_path(assembly_path) + except AssemblyPathError as e: + return _format_error(e, "path validation") + if ctx: - await ctx.info(f"Listing resources in: {assembly_path}") + await ctx.info(f"Listing resources in: {validated_path}") try: from .metadata_reader import MetadataReader - with MetadataReader(assembly_path) as reader: + with MetadataReader(validated_path) as reader: resources = reader.list_resources() if not resources: @@ -1502,13 +1619,19 @@ async def 
get_metadata_summary( Args: assembly_path: Full path to the .NET assembly file (.dll or .exe) """ + # Validate assembly path before any processing + try: + validated_path = _validate_assembly_path(assembly_path) + except AssemblyPathError as e: + return _format_error(e, "path validation") + if ctx: - await ctx.info(f"Getting metadata summary: {assembly_path}") + await ctx.info(f"Getting metadata summary: {validated_path}") try: from .metadata_reader import MetadataReader - with MetadataReader(assembly_path) as reader: + with MetadataReader(validated_path) as reader: meta = reader.get_assembly_metadata() content = "# Assembly Metadata Summary\n\n" diff --git a/tests/test_security.py b/tests/test_security.py new file mode 100644 index 0000000..4876a1e --- /dev/null +++ b/tests/test_security.py @@ -0,0 +1,288 @@ +"""Security-focused tests for mcilspy validation functions. + +These tests verify the security hardening in S1-S4: +- S1: Path traversal prevention via _validate_assembly_path() +- S2: Temp directory race condition fix (structural - uses TemporaryDirectory) +- S3: Subprocess output size limits (MAX_OUTPUT_BYTES) +- S4: Assembly file size limits (MAX_ASSEMBLY_SIZE_MB) +""" + +import os +import tempfile +from pathlib import Path +from unittest.mock import patch + +import pytest + + +class TestValidateAssemblyPath: + """Tests for S1: _validate_assembly_path() security validation.""" + + def test_empty_path_rejected(self): + """Empty paths should be rejected.""" + from mcilspy.server import AssemblyPathError, _validate_assembly_path + + with pytest.raises(AssemblyPathError, match="cannot be empty"): + _validate_assembly_path("") + + with pytest.raises(AssemblyPathError, match="cannot be empty"): + _validate_assembly_path(" ") + + def test_nonexistent_file_rejected(self): + """Non-existent files should be rejected.""" + from mcilspy.server import AssemblyPathError, _validate_assembly_path + + with pytest.raises(AssemblyPathError, match="not found"): + 
+ _validate_assembly_path("/nonexistent/path/to/assembly.dll") + + def test_directory_rejected(self): + """Directories should be rejected (file required).""" + from mcilspy.server import AssemblyPathError, _validate_assembly_path + + with pytest.raises(AssemblyPathError, match="not a file"): + _validate_assembly_path(tempfile.gettempdir()) + + def test_invalid_extension_rejected(self): + """Files without .dll or .exe extension should be rejected.""" + from mcilspy.server import AssemblyPathError, _validate_assembly_path + + # Create a temp file with wrong extension + with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as f: + temp_path = f.name + + try: + with pytest.raises(AssemblyPathError, match="Invalid assembly extension"): + _validate_assembly_path(temp_path) + finally: + os.unlink(temp_path) + + def test_valid_dll_accepted(self): + """Valid .dll files should be accepted.""" + from mcilspy.server import _validate_assembly_path + + # Create a temp file with .dll extension + with tempfile.NamedTemporaryFile(suffix=".dll", delete=False) as f: + temp_path = f.name + + try: + result = _validate_assembly_path(temp_path) + # Should return absolute path + assert os.path.isabs(result) + assert result.endswith(".dll") + finally: + os.unlink(temp_path) + + def test_valid_exe_accepted(self): + """Valid .exe files should be accepted.""" + from mcilspy.server import _validate_assembly_path + + # Create a temp file with .exe extension + with tempfile.NamedTemporaryFile(suffix=".exe", delete=False) as f: + temp_path = f.name + + try: + result = _validate_assembly_path(temp_path) + # Should return absolute path + assert os.path.isabs(result) + assert result.endswith(".exe") + finally: + os.unlink(temp_path) + + def test_case_insensitive_extension(self): + """Extension check should be case-insensitive.""" + from mcilspy.server import _validate_assembly_path + + # Create temp files with mixed case extensions + for ext in [".DLL", ".Dll", ".EXE", ".Exe"]: + with 
tempfile.NamedTemporaryFile(suffix=ext, delete=False) as f: + temp_path = f.name + + try: + result = _validate_assembly_path(temp_path) + assert os.path.isabs(result) + finally: + os.unlink(temp_path) + + def test_path_traversal_resolved(self): + """Path traversal attempts should be resolved to absolute paths.""" + from mcilspy.server import _validate_assembly_path + + # Create a temp file + with tempfile.NamedTemporaryFile(suffix=".dll", delete=False) as f: + temp_path = f.name + + try: + # Create a path with .. components + parent = os.path.dirname(temp_path) + filename = os.path.basename(temp_path) + traversal_path = os.path.join(parent, "..", os.path.basename(parent), filename) + + result = _validate_assembly_path(traversal_path) + # Should resolve to absolute path without .. + assert ".." not in result + assert os.path.isabs(result) + finally: + os.unlink(temp_path) + + def test_tilde_expansion(self): + """Home directory tilde should be expanded.""" + from mcilspy.server import AssemblyPathError, _validate_assembly_path + + # This should fail because the file doesn't exist, but the path should be expanded + with pytest.raises(AssemblyPathError, match="not found"): + _validate_assembly_path("~/nonexistent.dll") + + +class TestMaxOutputBytes: + """Tests for S3: MAX_OUTPUT_BYTES constant and truncation.""" + + def test_constant_defined(self): + """MAX_OUTPUT_BYTES constant should be defined.""" + from mcilspy.ilspy_wrapper import MAX_OUTPUT_BYTES + + assert MAX_OUTPUT_BYTES == 50_000_000 # 50 MB + assert MAX_OUTPUT_BYTES > 0 + + def test_constant_reasonable_size(self): + """MAX_OUTPUT_BYTES should be a reasonable size (not too small, not too large).""" + from mcilspy.ilspy_wrapper import MAX_OUTPUT_BYTES + + # Should be at least 1 MB + assert MAX_OUTPUT_BYTES >= 1_000_000 + # Should be at most 100 MB to prevent memory issues + assert MAX_OUTPUT_BYTES <= 100_000_000 + + +class TestMaxAssemblySize: + """Tests for S4: MAX_ASSEMBLY_SIZE_MB constant and size check.""" 
+
+    def test_constant_defined(self):
+        """MAX_ASSEMBLY_SIZE_MB constant should be defined."""
+        from mcilspy.metadata_reader import MAX_ASSEMBLY_SIZE_MB
+
+        assert MAX_ASSEMBLY_SIZE_MB == 500  # 500 MB
+        assert MAX_ASSEMBLY_SIZE_MB > 0
+
+    def test_assembly_size_error_defined(self):
+        """AssemblySizeError exception should be defined."""
+        from mcilspy.metadata_reader import AssemblySizeError
+
+        assert issubclass(AssemblySizeError, ValueError)
+
+    def test_oversized_file_rejected(self):
+        """Files exceeding MAX_ASSEMBLY_SIZE_MB should be rejected."""
+        from mcilspy.metadata_reader import (
+            AssemblySizeError,
+            MAX_ASSEMBLY_SIZE_MB,
+            MetadataReader,
+        )
+
+        # Create a temp file
+        with tempfile.NamedTemporaryFile(suffix=".dll", delete=False) as f:
+            temp_path = f.name
+
+        try:
+            # Mock the file size to be over the limit
+            mock_stat_result = os.stat(temp_path)
+            oversized_bytes = (MAX_ASSEMBLY_SIZE_MB + 1) * 1024 * 1024
+
+            with patch.object(Path, "stat") as mock_stat:
+                mock_stat.return_value = type(
+                    "StatResult",
+                    (),
+                    {"st_size": oversized_bytes},
+                )()
+
+                with pytest.raises(AssemblySizeError, match="exceeds maximum"):
+                    MetadataReader(temp_path)
+        finally:
+            os.unlink(temp_path)
+
+    def test_normal_sized_file_accepted(self):
+        """Files under MAX_ASSEMBLY_SIZE_MB should be accepted (at init)."""
+        from mcilspy.metadata_reader import MetadataReader
+
+        # Create a small temp file
+        with tempfile.NamedTemporaryFile(suffix=".dll", delete=False) as f:
+            f.write(b"small content")
+            temp_path = f.name
+
+        try:
+            # Should not raise on init (will fail later when trying to parse)
+            reader = MetadataReader(temp_path)
+            assert reader.assembly_path.exists()
+        finally:
+            os.unlink(temp_path)
+
+
+class TestTemporaryDirectoryUsage:
+    """Tests for S2: TemporaryDirectory context manager usage.
+
+    These are structural tests that verify the code uses the secure pattern.
+    """
+
+    def test_decompile_uses_temp_directory_context(self):
+        """Verify decompile method structure uses TemporaryDirectory."""
+        import inspect
+
+        from mcilspy.ilspy_wrapper import ILSpyWrapper
+
+        # Get source code of decompile method
+        source = inspect.getsource(ILSpyWrapper.decompile)
+
+        # Should use TemporaryDirectory context manager
+        assert "tempfile.TemporaryDirectory()" in source
+        assert "with tempfile.TemporaryDirectory()" in source
+
+        # Should NOT use the old mkdtemp pattern
+        assert "tempfile.mkdtemp()" not in source
+
+    def test_get_assembly_info_uses_temp_directory_context(self):
+        """Verify get_assembly_info method structure uses TemporaryDirectory."""
+        import inspect
+
+        from mcilspy.ilspy_wrapper import ILSpyWrapper
+
+        # Get source code of get_assembly_info method
+        source = inspect.getsource(ILSpyWrapper.get_assembly_info)
+
+        # Should use TemporaryDirectory context manager
+        assert "with tempfile.TemporaryDirectory()" in source
+
+        # Should NOT use the old mkdtemp pattern
+        assert "tempfile.mkdtemp()" not in source
+
+
+class TestPathValidationIntegration:
+    """Integration tests to verify path validation is applied to all tools."""
+
+    def test_all_tools_have_path_validation(self):
+        """Verify all assembly-accepting tools call _validate_assembly_path."""
+        import inspect
+
+        from mcilspy import server
+
+        # List of tool functions that accept assembly_path
+        tools_with_assembly_path = [
+            "decompile_assembly",
+            "list_types",
+            "generate_diagrammer",
+            "get_assembly_info",
+            "search_types",
+            "search_strings",
+            "search_methods",
+            "search_fields",
+            "search_properties",
+            "list_events",
+            "list_resources",
+            "get_metadata_summary",
+        ]
+
+        for tool_name in tools_with_assembly_path:
+            tool_func = getattr(server, tool_name)
+            source = inspect.getsource(tool_func)
+
+            # Each tool should call _validate_assembly_path
+            assert "_validate_assembly_path" in source, (
+                f"Tool '{tool_name}' does not call _validate_assembly_path"
+            )