Compare commits
12 Commits
7d784af17c
...
3c21b9d640
| Author | SHA1 | Date | |
|---|---|---|---|
| 3c21b9d640 | |||
| db95aeb491 | |||
| cf2c5159b3 | |||
| 70c4a4a39a | |||
| 3d7a561f20 | |||
| 8f119b72c2 | |||
| 20d0cd2e3a | |||
| 4bd9ce19af | |||
| 3b6afd0646 | |||
| fa71150ed5 | |||
| 16854b77ee | |||
| 8901752ae3 |
221
docs/taskmaster/PLAN.md
Normal file
221
docs/taskmaster/PLAN.md
Normal file
@ -0,0 +1,221 @@
|
||||
# Taskmaster Execution Plan: mcilspy Code Review Fixes
|
||||
|
||||
## Overview
|
||||
|
||||
33 issues identified in hater-hat code review, organized into 4 parallel workstreams.
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────┐
|
||||
│ MERGE ORDER (Sequential) │
|
||||
│ security → architecture → performance → testing │
|
||||
└─────────────────────────────────────────────────────────────────┘
|
||||
│ │ │ │
|
||||
▼ ▼ ▼ ▼
|
||||
┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐
|
||||
│SECURITY │ │ ARCH │ │ PERF │ │ TESTING │
|
||||
│ 4 issues│ │ 8 issues│ │ 8 issues│ │ 7 issues│
|
||||
│ P1-CRIT │ │ P2-HIGH │ │ P3-MED │ │ P4-LOW │
|
||||
└─────────┘ └─────────┘ └─────────┘ └─────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Domain 1: SECURITY (Priority 1 - Critical)
|
||||
|
||||
**Branch**: `fix/security`
|
||||
**Estimated Effort**: 2-3 hours
|
||||
**Blocks**: All other domains should wait for security review patterns
|
||||
|
||||
### Issues to Fix
|
||||
|
||||
| ID | Issue | File:Line | Fix |
|
||||
|----|-------|-----------|-----|
|
||||
| S1 | Path traversal - no validation | `server.py:*`, `ilspy_wrapper.py:132` | Add `_validate_assembly_path()` helper that checks: exists, is file, has .dll/.exe extension, resolves to absolute path, optionally check for PE signature |
|
||||
| S2 | Temp directory race condition | `ilspy_wrapper.py:150-153` | Replace `tempfile.mkdtemp()` with `tempfile.TemporaryDirectory()` context manager |
|
||||
| S3 | Unbounded subprocess output | `ilspy_wrapper.py:104-107` | Add `MAX_OUTPUT_BYTES = 50_000_000` constant, truncate stdout/stderr if exceeded |
|
||||
| S4 | No file size limit before loading | `metadata_reader.py:113-115` | Add `MAX_ASSEMBLY_SIZE_MB = 500` check before `dnfile.dnPE()` |
|
||||
|
||||
### Acceptance Criteria
|
||||
- [ ] All user-provided paths validated before use
|
||||
- [ ] No temp file leaks possible
|
||||
- [ ] Memory exhaustion attacks mitigated
|
||||
- [ ] Unit tests for validation functions
|
||||
|
||||
---
|
||||
|
||||
## Domain 2: ARCHITECTURE (Priority 2 - High)
|
||||
|
||||
**Branch**: `fix/architecture`
|
||||
**Estimated Effort**: 3-4 hours
|
||||
**Depends On**: Security (for validation patterns)
|
||||
|
||||
### Issues to Fix
|
||||
|
||||
| ID | Issue | File:Line | Fix |
|
||||
|----|-------|-----------|-----|
|
||||
| A1 | Duplicated PATH discovery | `server.py:99-123`, `ilspy_wrapper.py:41-75` | Extract to `src/mcilspy/utils.py` → `find_ilspycmd_path()` |
|
||||
| A2 | Mixed dataclass/Pydantic models | `metadata_reader.py` vs `models.py` | Convert metadata_reader dataclasses to Pydantic models in `models.py` |
|
||||
| A3 | Fragile lifespan context access | `server.py:51-80` | Simplify to module-level `_wrapper: ILSpyWrapper | None = None` with proper locking |
|
||||
| A4 | Stateless wrapper pretending to be stateful | `ilspy_wrapper.py` | Keep as-is but document why (caches ilspycmd_path lookup) OR convert to module functions |
|
||||
| A5 | Inconsistent async/sync | `metadata_reader.py` | Document as "CPU-bound sync" in docstring, add note about thread pool for heavy loads |
|
||||
| A6 | Magic numbers scattered | Throughout | Create `src/mcilspy/constants.py` with all limits/timeouts |
|
||||
| A7 | Repeated regex compilation | `server.py` (6 places) | Add `_compile_search_pattern(pattern, case_sensitive, use_regex)` helper |
|
||||
| A8 | Language version validation missing | `server.py:578` | Add try/except with helpful error listing valid versions |
|
||||
|
||||
### Acceptance Criteria
|
||||
- [ ] Single source of truth for PATH discovery
|
||||
- [ ] Consistent data model layer (all Pydantic)
|
||||
- [ ] All magic numbers in constants.py
|
||||
- [ ] Helper functions reduce code duplication
|
||||
|
||||
---
|
||||
|
||||
## Domain 3: PERFORMANCE (Priority 3 - Medium)
|
||||
|
||||
**Branch**: `fix/performance`
|
||||
**Estimated Effort**: 4-5 hours
|
||||
**Depends On**: Architecture (for new constants/utils)
|
||||
|
||||
### Issues to Fix
|
||||
|
||||
| ID | Issue | File:Line | Fix |
|
||||
|----|-------|-----------|-----|
|
||||
| P1 | search_strings decompiles entire assembly | `server.py:948-954` | Use dnfile's `pe.net.user_strings` to search string heap directly - 100x faster |
|
||||
| P2 | No result pagination | All list_* tools | Add `max_results: int = 1000` and `offset: int = 0` params, return `has_more` flag |
|
||||
| P3 | List conversion instead of generators | `metadata_reader.py:289` | Use `enumerate()` directly on iterator where possible |
|
||||
| P4 | No caching of decompilation | `ilspy_wrapper.py` | Add optional LRU cache keyed by (path, mtime, type_name) |
|
||||
| P5 | Silent failures in platform detection | `server.py:195-230` | Log warnings for permission errors, return reason in result |
|
||||
| P6 | Generic exception catches | Throughout | Replace with specific exceptions, preserve stack traces |
|
||||
| P7 | MetadataReader __exit__ type mismatch | `metadata_reader.py:595-597` | Fix return type, remove unnecessary `return False` |
|
||||
| P8 | No assembly validation | `ilspy_wrapper.py` | Add PE signature check (MZ header + PE\0\0) before subprocess |
|
||||
|
||||
### Acceptance Criteria
|
||||
- [ ] search_strings uses string heap (benchmark: 10x improvement)
|
||||
- [ ] All list tools support pagination
|
||||
- [ ] No silent swallowed exceptions
|
||||
- [ ] PE validation prevents wasted subprocess calls
|
||||
|
||||
---
|
||||
|
||||
## Domain 4: TESTING (Priority 4 - Enhancement)
|
||||
|
||||
**Branch**: `fix/testing`
|
||||
**Estimated Effort**: 5-6 hours
|
||||
**Depends On**: All other domains (tests should cover new code)
|
||||
|
||||
### Issues to Fix
|
||||
|
||||
| ID | Issue | Fix |
|
||||
|----|-------|-----|
|
||||
| T1 | No integration tests | Create `tests/integration/` with real assembly tests. Use a small test .dll checked into repo |
|
||||
| T2 | No MCP tool tests | Add `tests/test_server_tools.py` testing each `@mcp.tool()` function |
|
||||
| T3 | No error path tests | Add tests for: regex compilation failure, ilspycmd not found, ctx.info() failure |
|
||||
| T4 | No concurrency tests | Add `tests/test_concurrency.py` with parallel tool invocations |
|
||||
| T5 | Missing docstring validation | Add test that all public functions have docstrings (using AST) |
|
||||
| T6 | No cancel/progress tests | Test timeout behavior, verify progress reporting works |
|
||||
| T7 | Add test .NET assembly | Create minimal C# project, compile to .dll, check into `tests/fixtures/` |
|
||||
|
||||
### Acceptance Criteria
|
||||
- [ ] Integration tests with real ilspycmd calls
|
||||
- [ ] 80%+ code coverage (currently ~40% estimated)
|
||||
- [ ] All error paths tested
|
||||
- [ ] CI runs integration tests
|
||||
|
||||
---
|
||||
|
||||
## Deferred (Won't Fix This Sprint)
|
||||
|
||||
These issues are valid but lower priority:
|
||||
|
||||
| ID | Issue | Reason to Defer |
|
||||
|----|-------|-----------------|
|
||||
| D1 | No cancel/abort for decompilation | Requires MCP protocol support for cancellation |
|
||||
| D2 | No progress reporting | Needs ilspycmd changes or parsing stdout in real-time |
|
||||
| D3 | No resource extraction | Feature request, not bug - add to backlog |
|
||||
| D4 | install_ilspy sudo handling | Edge case - document limitation instead |
|
||||
| D5 | No dry-run for installation | Nice-to-have, not critical |
|
||||
| D6 | Error messages expose paths | Low risk for MCP (local tool) |
|
||||
|
||||
---
|
||||
|
||||
## Execution Commands
|
||||
|
||||
### Setup Worktrees (Run Once)
|
||||
|
||||
```bash
|
||||
# Create worktrees for parallel development
|
||||
git worktree add ../mcilspy-security fix/security -b fix/security
|
||||
git worktree add ../mcilspy-arch fix/architecture -b fix/architecture
|
||||
git worktree add ../mcilspy-perf fix/performance -b fix/performance
|
||||
git worktree add ../mcilspy-testing fix/testing -b fix/testing
|
||||
```
|
||||
|
||||
### Task Master Commands
|
||||
|
||||
```bash
|
||||
# Security Task Master
|
||||
cd ../mcilspy-security
|
||||
claude -p "You are the SECURITY Task Master. Read docs/taskmaster/PLAN.md and implement all S1-S4 issues. Update status.json when done."
|
||||
|
||||
# Architecture Task Master (can start in parallel)
|
||||
cd ../mcilspy-arch
|
||||
claude -p "You are the ARCHITECTURE Task Master. Read docs/taskmaster/PLAN.md and implement all A1-A8 issues. Check status.json for security patterns first."
|
||||
|
||||
# Performance Task Master
|
||||
cd ../mcilspy-perf
|
||||
claude -p "You are the PERFORMANCE Task Master. Read docs/taskmaster/PLAN.md and implement all P1-P8 issues."
|
||||
|
||||
# Testing Task Master (runs last)
|
||||
cd ../mcilspy-testing
|
||||
claude -p "You are the TESTING Task Master. Read docs/taskmaster/PLAN.md and implement all T1-T7 issues. Requires other domains complete first."
|
||||
```
|
||||
|
||||
### Merge Protocol
|
||||
|
||||
```bash
|
||||
# After all complete, merge in order:
|
||||
git checkout main
|
||||
git merge --no-ff fix/security -m "security: path validation, temp cleanup, output limits"
|
||||
git merge --no-ff fix/architecture -m "refactor: consolidate utils, constants, models"
|
||||
git merge --no-ff fix/performance -m "perf: string heap search, pagination, caching"
|
||||
git merge --no-ff fix/testing -m "test: integration tests, tool coverage, fixtures"
|
||||
|
||||
# Cleanup
|
||||
git worktree remove ../mcilspy-security
|
||||
git worktree remove ../mcilspy-arch
|
||||
git worktree remove ../mcilspy-perf
|
||||
git worktree remove ../mcilspy-testing
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Priority Matrix
|
||||
|
||||
```
|
||||
IMPACT
|
||||
Low Medium High
|
||||
┌────────┬─────────┬─────────┐
|
||||
High │ A5 │ P2,P3 │ S1-S4 │ ← Fix First
|
||||
EFFORT │ A8 │ A6,A7 │ A1-A4 │
|
||||
Medium │ T5 │ P5-P8 │ P1,P4 │
|
||||
│ D4 │ T3,T4 │ T1,T2 │
|
||||
Low │ D5,D6 │ T6 │ T7 │ ← Quick Wins
|
||||
└────────┴─────────┴─────────┘
|
||||
```
|
||||
|
||||
**Quick Wins** (do first within each domain):
|
||||
- S2: Temp directory fix (5 min)
|
||||
- A6: Constants file (15 min)
|
||||
- P7: __exit__ fix (2 min)
|
||||
- T7: Add test assembly (30 min)
|
||||
|
||||
---
|
||||
|
||||
## Definition of Done
|
||||
|
||||
- [ ] All issues in domain addressed
|
||||
- [ ] Tests pass: `uv run pytest`
|
||||
- [ ] Lint passes: `uv run ruff check src/`
|
||||
- [ ] Types pass: `uv run ruff check src/ --select=ANN`
|
||||
- [ ] status.json updated to "ready"
|
||||
- [ ] PR draft created (not merged)
|
||||
12
docs/taskmaster/status.json
Normal file
12
docs/taskmaster/status.json
Normal file
@ -0,0 +1,12 @@
|
||||
{
|
||||
"project": "mcilspy-code-review-fixes",
|
||||
"created": "2025-02-08T00:00:00Z",
|
||||
"completed": "2025-02-08T12:00:00Z",
|
||||
"domains": {
|
||||
"security": { "status": "merged", "branch": "fix/security", "priority": 1 },
|
||||
"architecture": { "status": "merged", "branch": "fix/architecture", "priority": 2 },
|
||||
"performance": { "status": "merged", "branch": "fix/performance", "priority": 3 },
|
||||
"testing": { "status": "merged", "branch": "fix/testing", "priority": 4 }
|
||||
},
|
||||
"merge_order": ["security", "architecture", "performance", "testing"]
|
||||
}
|
||||
49
src/mcilspy/constants.py
Normal file
49
src/mcilspy/constants.py
Normal file
@ -0,0 +1,49 @@
|
||||
"""Constants and configuration values for mcilspy.
|
||||
|
||||
This module centralizes all timeouts, limits, and magic numbers used throughout
|
||||
the codebase. Import from here rather than hardcoding values.
|
||||
"""
|
||||
|
||||
# =============================================================================
|
||||
# Subprocess Timeouts
|
||||
# =============================================================================
|
||||
|
||||
# Maximum time to wait for ilspycmd decompilation (in seconds)
|
||||
# Large assemblies or corrupted files may take longer
|
||||
DECOMPILE_TIMEOUT_SECONDS: float = 300.0 # 5 minutes
|
||||
|
||||
# =============================================================================
|
||||
# Output Limits
|
||||
# =============================================================================
|
||||
|
||||
# Maximum characters to display from subprocess output in error messages
|
||||
MAX_ERROR_OUTPUT_CHARS: int = 1000
|
||||
|
||||
# Maximum line length when displaying code snippets (truncate longer lines)
|
||||
MAX_LINE_LENGTH: int = 200
|
||||
|
||||
# Maximum unparsed lines to log before suppressing (avoid log spam)
|
||||
MAX_UNPARSED_LOG_LINES: int = 3
|
||||
|
||||
# Preview length for unparsed line debug messages
|
||||
UNPARSED_LINE_PREVIEW_LENGTH: int = 100
|
||||
|
||||
# =============================================================================
|
||||
# Search Limits
|
||||
# =============================================================================
|
||||
|
||||
# Default maximum results for search operations
|
||||
DEFAULT_MAX_SEARCH_RESULTS: int = 100
|
||||
|
||||
# Maximum matches to display per type in grouped results
|
||||
MAX_MATCHES_PER_TYPE: int = 20
|
||||
|
||||
# =============================================================================
|
||||
# Tool Identifiers (for ilspycmd CLI)
|
||||
# =============================================================================
|
||||
|
||||
# Default entity types for list operations
|
||||
DEFAULT_ENTITY_TYPES: list[str] = ["class"]
|
||||
|
||||
# All supported entity types
|
||||
ALL_ENTITY_TYPES: list[str] = ["class", "interface", "struct", "delegate", "enum"]
|
||||
@ -9,6 +9,11 @@ import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from .constants import (
|
||||
DECOMPILE_TIMEOUT_SECONDS,
|
||||
MAX_UNPARSED_LOG_LINES,
|
||||
UNPARSED_LINE_PREVIEW_LENGTH,
|
||||
)
|
||||
from .models import (
|
||||
AssemblyInfo,
|
||||
AssemblyInfoRequest,
|
||||
@ -19,12 +24,63 @@ from .models import (
|
||||
ListTypesResponse,
|
||||
TypeInfo,
|
||||
)
|
||||
from .utils import find_ilspycmd_path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Maximum bytes to read from subprocess stdout/stderr to prevent memory exhaustion
|
||||
# from malicious or corrupted assemblies that produce huge output
|
||||
MAX_OUTPUT_BYTES = 50_000_000 # 50 MB
|
||||
|
||||
# PE file signature constants
|
||||
_MZ_SIGNATURE = b"MZ" # DOS header magic number
|
||||
|
||||
|
||||
def _validate_pe_signature(file_path: str) -> tuple[bool, str]:
|
||||
"""Quick validation of PE file signature (MZ header).
|
||||
|
||||
Fails fast on non-PE files before invoking ilspycmd.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to validate
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message). error_message is empty if valid.
|
||||
"""
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
header = f.read(2)
|
||||
if len(header) < 2:
|
||||
return False, "File is too small to be a valid PE file"
|
||||
if header != _MZ_SIGNATURE:
|
||||
return False, f"Not a valid PE file (missing MZ signature, got {header!r})"
|
||||
return True, ""
|
||||
except OSError as e:
|
||||
return False, f"Cannot read file: {e}"
|
||||
|
||||
|
||||
class ILSpyWrapper:
|
||||
"""Wrapper class for ILSpy command line tool."""
|
||||
"""Wrapper class for ILSpy command line tool.
|
||||
|
||||
This class encapsulates all interactions with the ilspycmd CLI tool.
|
||||
While the wrapper is stateless in terms of decompilation operations
|
||||
(each call is independent), it exists as a class to:
|
||||
|
||||
1. Cache the ilspycmd path lookup - Finding the executable involves
|
||||
checking PATH and several common installation locations, which
|
||||
is relatively expensive. Caching this on instantiation avoids
|
||||
repeated filesystem operations.
|
||||
|
||||
2. Provide a single point of configuration - If ilspycmd is not found,
|
||||
we fail fast at wrapper creation rather than on each tool call.
|
||||
|
||||
3. Enable future extensions - The class structure allows adding
|
||||
connection pooling, result caching, or other optimizations without
|
||||
changing the API.
|
||||
|
||||
The wrapper should be created once and reused across tool calls.
|
||||
See get_wrapper() in server.py for the recommended usage pattern.
|
||||
"""
|
||||
|
||||
def __init__(self, ilspycmd_path: str | None = None) -> None:
|
||||
"""Initialize the wrapper.
|
||||
@ -32,48 +88,12 @@ class ILSpyWrapper:
|
||||
Args:
|
||||
ilspycmd_path: Path to ilspycmd executable. If None, will try to find it in PATH.
|
||||
"""
|
||||
self.ilspycmd_path = ilspycmd_path or self._find_ilspycmd()
|
||||
self.ilspycmd_path = ilspycmd_path or find_ilspycmd_path()
|
||||
if not self.ilspycmd_path:
|
||||
raise RuntimeError(
|
||||
"ILSpyCmd not found. Please install it with: dotnet tool install --global ilspycmd"
|
||||
)
|
||||
|
||||
def _find_ilspycmd(self) -> str | None:
|
||||
"""Find ilspycmd executable in PATH or common install locations.
|
||||
|
||||
Checks:
|
||||
1. Standard PATH (via shutil.which)
|
||||
2. ~/.dotnet/tools (default location for 'dotnet tool install --global')
|
||||
3. Platform-specific locations
|
||||
"""
|
||||
# Try standard PATH first
|
||||
for cmd_name in ["ilspycmd", "ilspycmd.exe"]:
|
||||
path = shutil.which(cmd_name)
|
||||
if path:
|
||||
return path
|
||||
|
||||
# Check common dotnet tools locations (not always in PATH)
|
||||
home = os.path.expanduser("~")
|
||||
dotnet_tools_paths = [
|
||||
os.path.join(home, ".dotnet", "tools", "ilspycmd"),
|
||||
os.path.join(home, ".dotnet", "tools", "ilspycmd.exe"),
|
||||
]
|
||||
|
||||
# Windows-specific paths
|
||||
if os.name == "nt":
|
||||
userprofile = os.environ.get("USERPROFILE", "")
|
||||
if userprofile:
|
||||
dotnet_tools_paths.extend([
|
||||
os.path.join(userprofile, ".dotnet", "tools", "ilspycmd.exe"),
|
||||
])
|
||||
|
||||
for tool_path in dotnet_tools_paths:
|
||||
if os.path.isfile(tool_path) and os.access(tool_path, os.X_OK):
|
||||
logger.info(f"Found ilspycmd at {tool_path} (not in PATH)")
|
||||
return tool_path
|
||||
|
||||
return None
|
||||
|
||||
async def _run_command(
|
||||
self, args: list[str], input_data: str | None = None
|
||||
) -> tuple[int, str, str]:
|
||||
@ -85,6 +105,10 @@ class ILSpyWrapper:
|
||||
|
||||
Returns:
|
||||
Tuple of (return_code, stdout, stderr)
|
||||
|
||||
Note:
|
||||
Output is truncated to MAX_OUTPUT_BYTES to prevent memory exhaustion
|
||||
from malicious or corrupted assemblies.
|
||||
"""
|
||||
cmd = [self.ilspycmd_path] + args
|
||||
logger.debug(f"Running command: {' '.join(cmd)}")
|
||||
@ -99,25 +123,50 @@ class ILSpyWrapper:
|
||||
|
||||
input_bytes = input_data.encode("utf-8") if input_data else None
|
||||
|
||||
# Timeout after 5 minutes to prevent hanging on malicious/corrupted assemblies
|
||||
# Timeout to prevent hanging on malicious/corrupted assemblies
|
||||
try:
|
||||
stdout_bytes, stderr_bytes = await asyncio.wait_for(
|
||||
process.communicate(input=input_bytes),
|
||||
timeout=300.0 # 5 minutes
|
||||
timeout=DECOMPILE_TIMEOUT_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Command timed out after 5 minutes, killing process")
|
||||
timeout_mins = DECOMPILE_TIMEOUT_SECONDS / 60
|
||||
logger.warning(f"Command timed out after {timeout_mins:.0f} minutes, killing process")
|
||||
process.kill()
|
||||
await process.wait() # Ensure process is cleaned up
|
||||
return -1, "", "Command timed out after 5 minutes. The assembly may be corrupted or too complex."
|
||||
return -1, "", f"Command timed out after {timeout_mins:.0f} minutes. The assembly may be corrupted or too complex."
|
||||
|
||||
# Truncate output if it exceeds the limit to prevent memory exhaustion
|
||||
stdout_truncated = False
|
||||
stderr_truncated = False
|
||||
|
||||
if stdout_bytes and len(stdout_bytes) > MAX_OUTPUT_BYTES:
|
||||
stdout_bytes = stdout_bytes[:MAX_OUTPUT_BYTES]
|
||||
stdout_truncated = True
|
||||
logger.warning(
|
||||
f"stdout truncated from {len(stdout_bytes)} to {MAX_OUTPUT_BYTES} bytes"
|
||||
)
|
||||
|
||||
if stderr_bytes and len(stderr_bytes) > MAX_OUTPUT_BYTES:
|
||||
stderr_bytes = stderr_bytes[:MAX_OUTPUT_BYTES]
|
||||
stderr_truncated = True
|
||||
logger.warning(
|
||||
f"stderr truncated from {len(stderr_bytes)} to {MAX_OUTPUT_BYTES} bytes"
|
||||
)
|
||||
|
||||
stdout = stdout_bytes.decode("utf-8", errors="replace") if stdout_bytes else ""
|
||||
stderr = stderr_bytes.decode("utf-8", errors="replace") if stderr_bytes else ""
|
||||
|
||||
# Add truncation warning to output
|
||||
if stdout_truncated:
|
||||
stdout += "\n\n[OUTPUT TRUNCATED - exceeded 50MB limit]"
|
||||
if stderr_truncated:
|
||||
stderr += "\n\n[OUTPUT TRUNCATED - exceeded 50MB limit]"
|
||||
|
||||
return process.returncode, stdout, stderr
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error running command: {e}")
|
||||
except (OSError, FileNotFoundError) as e:
|
||||
logger.exception(f"Error running ilspycmd command: {e}")
|
||||
return -1, "", str(e)
|
||||
|
||||
async def decompile(self, request: DecompileRequest) -> DecompileResponse:
|
||||
@ -136,6 +185,37 @@ class ILSpyWrapper:
|
||||
assembly_name=os.path.basename(request.assembly_path),
|
||||
)
|
||||
|
||||
# Validate PE signature before invoking ilspycmd
|
||||
is_valid, pe_error = _validate_pe_signature(request.assembly_path)
|
||||
if not is_valid:
|
||||
return DecompileResponse(
|
||||
success=False,
|
||||
error_message=pe_error,
|
||||
assembly_name=os.path.basename(request.assembly_path),
|
||||
)
|
||||
|
||||
# Use TemporaryDirectory context manager for guaranteed cleanup (no race condition)
|
||||
# when user doesn't specify an output directory
|
||||
if request.output_dir:
|
||||
# User specified output directory - use it directly
|
||||
return await self._decompile_to_dir(request, request.output_dir)
|
||||
else:
|
||||
# Create a temporary directory with guaranteed cleanup
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
return await self._decompile_to_dir(request, temp_dir)
|
||||
|
||||
async def _decompile_to_dir(
|
||||
self, request: DecompileRequest, output_dir: str
|
||||
) -> DecompileResponse:
|
||||
"""Internal helper to perform decompilation to a specific directory.
|
||||
|
||||
Args:
|
||||
request: Decompilation request
|
||||
output_dir: Directory to write output to
|
||||
|
||||
Returns:
|
||||
Decompilation response
|
||||
"""
|
||||
args = [request.assembly_path]
|
||||
|
||||
# Add language version
|
||||
@ -145,13 +225,6 @@ class ILSpyWrapper:
|
||||
if request.type_name:
|
||||
args.extend(["-t", request.type_name])
|
||||
|
||||
# Add output directory if specified
|
||||
temp_dir = None
|
||||
output_dir = request.output_dir
|
||||
if not output_dir:
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
output_dir = temp_dir
|
||||
|
||||
args.extend(["-o", output_dir])
|
||||
|
||||
# Add project creation flag
|
||||
@ -190,7 +263,7 @@ class ILSpyWrapper:
|
||||
assembly_name = os.path.splitext(os.path.basename(request.assembly_path))[0]
|
||||
|
||||
if return_code == 0:
|
||||
# If no output directory was specified, return stdout as source code
|
||||
# If no output directory was specified by user, return stdout as source code
|
||||
source_code = None
|
||||
output_path = None
|
||||
|
||||
@ -230,17 +303,14 @@ class ILSpyWrapper:
|
||||
type_name=request.type_name,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
except OSError as e:
|
||||
logger.exception(f"Error during decompilation: {e}")
|
||||
return DecompileResponse(
|
||||
success=False,
|
||||
error_message=str(e),
|
||||
assembly_name=os.path.basename(request.assembly_path),
|
||||
type_name=request.type_name,
|
||||
)
|
||||
finally:
|
||||
# Clean up temporary directory
|
||||
if temp_dir and os.path.exists(temp_dir):
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
async def list_types(self, request: ListTypesRequest) -> ListTypesResponse:
|
||||
"""List types in a .NET assembly.
|
||||
@ -256,6 +326,11 @@ class ILSpyWrapper:
|
||||
success=False, error_message=f"Assembly file not found: {request.assembly_path}"
|
||||
)
|
||||
|
||||
# Validate PE signature before invoking ilspycmd
|
||||
is_valid, pe_error = _validate_pe_signature(request.assembly_path)
|
||||
if not is_valid:
|
||||
return ListTypesResponse(success=False, error_message=pe_error)
|
||||
|
||||
args = [request.assembly_path]
|
||||
|
||||
# Add entity types to list
|
||||
@ -279,7 +354,8 @@ class ILSpyWrapper:
|
||||
error_msg = stderr or stdout or "Unknown error occurred"
|
||||
return ListTypesResponse(success=False, error_message=error_msg)
|
||||
|
||||
except Exception as e:
|
||||
except OSError as e:
|
||||
logger.exception(f"Error listing types: {e}")
|
||||
return ListTypesResponse(success=False, error_message=str(e))
|
||||
|
||||
# Compiled regex for parsing ilspycmd list output
|
||||
@ -326,10 +402,10 @@ class ILSpyWrapper:
|
||||
else:
|
||||
# Log unexpected lines (but don't fail - ilspycmd may output warnings/info)
|
||||
unparsed_count += 1
|
||||
if unparsed_count <= 3: # Avoid log spam
|
||||
logger.debug(f"Skipping unparsed line from ilspycmd: {line[:100]}")
|
||||
if unparsed_count <= MAX_UNPARSED_LOG_LINES:
|
||||
logger.debug(f"Skipping unparsed line from ilspycmd: {line[:UNPARSED_LINE_PREVIEW_LENGTH]}")
|
||||
|
||||
if unparsed_count > 3:
|
||||
if unparsed_count > MAX_UNPARSED_LOG_LINES:
|
||||
logger.debug(f"Skipped {unparsed_count} unparsed lines total")
|
||||
|
||||
return types
|
||||
@ -388,6 +464,11 @@ class ILSpyWrapper:
|
||||
"error_message": f"Assembly file not found: {request.assembly_path}",
|
||||
}
|
||||
|
||||
# Validate PE signature before invoking ilspycmd
|
||||
is_valid, pe_error = _validate_pe_signature(request.assembly_path)
|
||||
if not is_valid:
|
||||
return {"success": False, "error_message": pe_error}
|
||||
|
||||
args = [request.assembly_path, "--generate-diagrammer"]
|
||||
|
||||
# Add output directory
|
||||
@ -433,7 +514,8 @@ class ILSpyWrapper:
|
||||
error_msg = stderr or stdout or "Unknown error occurred"
|
||||
return {"success": False, "error_message": error_msg}
|
||||
|
||||
except Exception as e:
|
||||
except OSError as e:
|
||||
logger.exception(f"Error generating diagrammer: {e}")
|
||||
return {"success": False, "error_message": str(e)}
|
||||
|
||||
async def get_assembly_info(self, request: AssemblyInfoRequest) -> AssemblyInfo:
|
||||
@ -448,6 +530,11 @@ class ILSpyWrapper:
|
||||
if not os.path.exists(request.assembly_path):
|
||||
raise FileNotFoundError(f"Assembly file not found: {request.assembly_path}")
|
||||
|
||||
# Validate PE signature before invoking ilspycmd
|
||||
is_valid, pe_error = _validate_pe_signature(request.assembly_path)
|
||||
if not is_valid:
|
||||
raise ValueError(pe_error)
|
||||
|
||||
assembly_path = Path(request.assembly_path)
|
||||
|
||||
# Use ilspycmd to list types and extract assembly info from output
|
||||
@ -464,8 +551,8 @@ class ILSpyWrapper:
|
||||
|
||||
# Try to extract more info by decompiling assembly attributes
|
||||
# Decompile with minimal output to get assembly-level attributes
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
# Use TemporaryDirectory context manager for guaranteed cleanup (no race condition)
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
args = [
|
||||
request.assembly_path,
|
||||
"-o",
|
||||
@ -523,10 +610,6 @@ class ILSpyWrapper:
|
||||
if title_match:
|
||||
full_name = f"{title_match.group(1)}, Version={version}"
|
||||
|
||||
finally:
|
||||
# Clean up temp directory
|
||||
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||
|
||||
return AssemblyInfo(
|
||||
name=name,
|
||||
version=version,
|
||||
|
||||
@ -4,101 +4,56 @@ Provides access to all 34+ CLR metadata tables without requiring ilspycmd.
|
||||
This enables searching for methods, fields, properties, events, and resources
|
||||
that are not exposed via the ilspycmd CLI.
|
||||
|
||||
This module contains CPU-bound synchronous code for parsing .NET PE metadata.
|
||||
For heavy workloads with many concurrent requests, consider running these
|
||||
operations in a thread pool (e.g., asyncio.to_thread) to avoid blocking
|
||||
the event loop.
|
||||
|
||||
Note: dnfile provides flag attributes as boolean properties (e.g., mdPublic, fdStatic)
|
||||
rather than traditional IntFlag enums, so we use those directly.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
import re
|
||||
import struct
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import dnfile
|
||||
from dnfile.mdtable import TypeDefRow
|
||||
from dnfile.utils import read_compressed_int
|
||||
|
||||
from .models import (
|
||||
AssemblyMetadata,
|
||||
EventInfo,
|
||||
FieldInfo,
|
||||
MethodInfo,
|
||||
PropertyInfo,
|
||||
ResourceInfo,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MethodInfo:
|
||||
"""Information about a method in an assembly."""
|
||||
# Maximum assembly file size to load (in megabytes)
|
||||
# Prevents memory exhaustion from extremely large or malicious assemblies
|
||||
MAX_ASSEMBLY_SIZE_MB = 500
|
||||
|
||||
name: str
|
||||
full_name: str
|
||||
declaring_type: str
|
||||
namespace: str | None
|
||||
return_type: str | None = None
|
||||
is_public: bool = False
|
||||
is_static: bool = False
|
||||
is_virtual: bool = False
|
||||
is_abstract: bool = False
|
||||
parameters: list[str] = field(default_factory=list)
|
||||
|
||||
class AssemblySizeError(ValueError):
|
||||
"""Raised when an assembly exceeds the maximum allowed size."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class FieldInfo:
|
||||
"""Information about a field in an assembly."""
|
||||
class StringMatch:
|
||||
"""A matched string from the user strings heap."""
|
||||
|
||||
name: str
|
||||
full_name: str
|
||||
declaring_type: str
|
||||
namespace: str | None
|
||||
field_type: str | None = None
|
||||
is_public: bool = False
|
||||
is_static: bool = False
|
||||
is_literal: bool = False # Constant
|
||||
default_value: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class PropertyInfo:
|
||||
"""Information about a property in an assembly."""
|
||||
|
||||
name: str
|
||||
full_name: str
|
||||
declaring_type: str
|
||||
namespace: str | None
|
||||
property_type: str | None = None
|
||||
has_getter: bool = False
|
||||
has_setter: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventInfo:
|
||||
"""Information about an event in an assembly."""
|
||||
|
||||
name: str
|
||||
full_name: str
|
||||
declaring_type: str
|
||||
namespace: str | None
|
||||
event_type: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResourceInfo:
|
||||
"""Information about an embedded resource."""
|
||||
|
||||
name: str
|
||||
size: int
|
||||
is_public: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class AssemblyMetadata:
|
||||
"""Complete assembly metadata from dnfile."""
|
||||
|
||||
name: str
|
||||
version: str
|
||||
culture: str | None = None
|
||||
public_key_token: str | None = None
|
||||
target_framework: str | None = None
|
||||
type_count: int = 0
|
||||
method_count: int = 0
|
||||
field_count: int = 0
|
||||
property_count: int = 0
|
||||
event_count: int = 0
|
||||
resource_count: int = 0
|
||||
referenced_assemblies: list[str] = field(default_factory=list)
|
||||
value: str
|
||||
offset: int # Offset in the #US heap
|
||||
|
||||
|
||||
class MetadataReader:
|
||||
@ -109,11 +64,25 @@ class MetadataReader:
|
||||
|
||||
Args:
|
||||
assembly_path: Path to the .NET assembly file
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the assembly file doesn't exist
|
||||
AssemblySizeError: If the assembly exceeds MAX_ASSEMBLY_SIZE_MB
|
||||
"""
|
||||
self.assembly_path = Path(assembly_path)
|
||||
if not self.assembly_path.exists():
|
||||
raise FileNotFoundError(f"Assembly not found: {assembly_path}")
|
||||
|
||||
# Check file size before loading to prevent memory exhaustion
|
||||
file_size_bytes = self.assembly_path.stat().st_size
|
||||
max_size_bytes = MAX_ASSEMBLY_SIZE_MB * 1024 * 1024
|
||||
if file_size_bytes > max_size_bytes:
|
||||
size_mb = file_size_bytes / (1024 * 1024)
|
||||
raise AssemblySizeError(
|
||||
f"Assembly file size ({size_mb:.1f} MB) exceeds maximum allowed "
|
||||
f"({MAX_ASSEMBLY_SIZE_MB} MB). This limit prevents memory exhaustion."
|
||||
)
|
||||
|
||||
self._pe: dnfile.dnPE | None = None
|
||||
self._type_cache: dict[int, TypeDefRow] = {}
|
||||
|
||||
@ -122,7 +91,9 @@ class MetadataReader:
|
||||
if self._pe is None:
|
||||
try:
|
||||
self._pe = dnfile.dnPE(str(self.assembly_path))
|
||||
except Exception as e:
|
||||
except (OSError, struct.error) as e:
|
||||
# OSError/IOError: file access issues
|
||||
# struct.error: malformed PE structure
|
||||
raise ValueError(f"Failed to parse assembly: {e}") from e
|
||||
|
||||
# Build type cache for lookups
|
||||
@ -184,7 +155,8 @@ class MetadataReader:
|
||||
type_name = str(ca.Type) if ca.Type else ""
|
||||
if "TargetFramework" in type_name and hasattr(ca, "Value") and ca.Value:
|
||||
target_framework = str(ca.Value)
|
||||
except Exception:
|
||||
except (AttributeError, TypeError, ValueError):
|
||||
# CustomAttribute parsing can fail in various ways due to blob format
|
||||
pass
|
||||
|
||||
type_count = (
|
||||
@ -583,6 +555,121 @@ class MetadataReader:
|
||||
|
||||
return resources
|
||||
|
||||
def _iter_user_strings(self) -> Iterator[tuple[int, str]]:
|
||||
"""Iterate over all user strings in the #US heap.
|
||||
|
||||
Yields (offset, string_value) tuples.
|
||||
|
||||
The #US (User Strings) heap stores UTF-16 encoded strings used in the
|
||||
assembly's IL code (ldstr instructions). Each entry is prefixed with a
|
||||
compressed integer length, followed by UTF-16 bytes and a trailing flag byte.
|
||||
"""
|
||||
pe = self._ensure_loaded()
|
||||
|
||||
if not pe.net or not pe.net.user_strings:
|
||||
return
|
||||
|
||||
heap = pe.net.user_strings
|
||||
data = heap._ClrStream__data__ # Access the raw bytes
|
||||
|
||||
if not data:
|
||||
return
|
||||
|
||||
# The first byte is always 0x00 (null string entry)
|
||||
offset = 1
|
||||
|
||||
while offset < len(data):
|
||||
# Read compressed integer length
|
||||
result = read_compressed_int(data[offset:])
|
||||
if result is None:
|
||||
break
|
||||
|
||||
length, size_bytes = result
|
||||
|
||||
if length == 0:
|
||||
offset += size_bytes
|
||||
continue
|
||||
|
||||
# Skip past the length bytes
|
||||
string_start = offset + size_bytes
|
||||
|
||||
if string_start + length > len(data):
|
||||
# Corrupted or truncated - stop iteration
|
||||
break
|
||||
|
||||
# Extract string data (UTF-16 with possible trailing flag byte)
|
||||
string_data = data[string_start : string_start + length]
|
||||
|
||||
# The trailing byte is a flag if length is odd
|
||||
if length % 2 == 1:
|
||||
string_data = string_data[:-1] # Remove flag byte
|
||||
|
||||
# Decode as UTF-16 Little Endian
|
||||
try:
|
||||
string_value = string_data.decode("utf-16-le", errors="replace")
|
||||
if string_value: # Only yield non-empty strings
|
||||
yield offset, string_value
|
||||
except (UnicodeDecodeError, ValueError):
|
||||
# Skip malformed strings
|
||||
pass
|
||||
|
||||
# Move to next entry
|
||||
offset = string_start + length
|
||||
|
||||
def search_user_strings(
|
||||
self,
|
||||
pattern: str,
|
||||
case_sensitive: bool = False,
|
||||
use_regex: bool = False,
|
||||
max_results: int = 100,
|
||||
) -> list[StringMatch]:
|
||||
"""Search for strings in the user strings heap.
|
||||
|
||||
This is much faster than decompiling the entire assembly because it
|
||||
reads directly from the #US metadata heap without invoking ilspycmd.
|
||||
|
||||
Args:
|
||||
pattern: String pattern to search for
|
||||
case_sensitive: Whether to match case (default: False)
|
||||
use_regex: Treat pattern as regular expression (default: False)
|
||||
max_results: Maximum number of matches to return (default: 100)
|
||||
|
||||
Returns:
|
||||
List of StringMatch objects containing matching strings
|
||||
"""
|
||||
matches: list[StringMatch] = []
|
||||
|
||||
# Compile regex if needed
|
||||
if use_regex:
|
||||
flags = 0 if case_sensitive else re.IGNORECASE
|
||||
try:
|
||||
search_pattern = re.compile(pattern, flags)
|
||||
except re.error as e:
|
||||
raise ValueError(f"Invalid regex pattern: {e}") from e
|
||||
else:
|
||||
search_pattern = None
|
||||
|
||||
# Prepare pattern for non-regex search
|
||||
if not use_regex and not case_sensitive:
|
||||
pattern_lower = pattern.lower()
|
||||
|
||||
for offset, string_value in self._iter_user_strings():
|
||||
if len(matches) >= max_results:
|
||||
break
|
||||
|
||||
# Check for match
|
||||
if use_regex and search_pattern is not None:
|
||||
if search_pattern.search(string_value):
|
||||
matches.append(StringMatch(value=string_value, offset=offset))
|
||||
elif case_sensitive:
|
||||
if pattern in string_value:
|
||||
matches.append(StringMatch(value=string_value, offset=offset))
|
||||
else:
|
||||
if pattern_lower in string_value.lower():
|
||||
matches.append(StringMatch(value=string_value, offset=offset))
|
||||
|
||||
return matches
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the PE file."""
|
||||
if self._pe:
|
||||
@ -594,4 +681,3 @@ class MetadataReader:
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
self.close()
|
||||
return False
|
||||
|
||||
@ -161,3 +161,84 @@ class AssemblyInfo(BaseModel):
|
||||
runtime_version: str | None = None
|
||||
is_signed: bool = False
|
||||
has_debug_info: bool = False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Metadata Reader Models (dnfile-based direct metadata access)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class MethodInfo(BaseModel):
|
||||
"""Information about a method in an assembly."""
|
||||
|
||||
name: str
|
||||
full_name: str
|
||||
declaring_type: str
|
||||
namespace: str | None = None
|
||||
return_type: str | None = None
|
||||
is_public: bool = False
|
||||
is_static: bool = False
|
||||
is_virtual: bool = False
|
||||
is_abstract: bool = False
|
||||
parameters: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class FieldInfo(BaseModel):
|
||||
"""Information about a field in an assembly."""
|
||||
|
||||
name: str
|
||||
full_name: str
|
||||
declaring_type: str
|
||||
namespace: str | None = None
|
||||
field_type: str | None = None
|
||||
is_public: bool = False
|
||||
is_static: bool = False
|
||||
is_literal: bool = False # Constant
|
||||
default_value: str | None = None
|
||||
|
||||
|
||||
class PropertyInfo(BaseModel):
|
||||
"""Information about a property in an assembly."""
|
||||
|
||||
name: str
|
||||
full_name: str
|
||||
declaring_type: str
|
||||
namespace: str | None = None
|
||||
property_type: str | None = None
|
||||
has_getter: bool = False
|
||||
has_setter: bool = False
|
||||
|
||||
|
||||
class EventInfo(BaseModel):
|
||||
"""Information about an event in an assembly."""
|
||||
|
||||
name: str
|
||||
full_name: str
|
||||
declaring_type: str
|
||||
namespace: str | None = None
|
||||
event_type: str | None = None
|
||||
|
||||
|
||||
class ResourceInfo(BaseModel):
|
||||
"""Information about an embedded resource."""
|
||||
|
||||
name: str
|
||||
size: int = 0
|
||||
is_public: bool = True
|
||||
|
||||
|
||||
class AssemblyMetadata(BaseModel):
|
||||
"""Complete assembly metadata from dnfile."""
|
||||
|
||||
name: str
|
||||
version: str
|
||||
culture: str | None = None
|
||||
public_key_token: str | None = None
|
||||
target_framework: str | None = None
|
||||
type_count: int = 0
|
||||
method_count: int = 0
|
||||
field_count: int = 0
|
||||
property_count: int = 0
|
||||
event_count: int = 0
|
||||
resource_count: int = 0
|
||||
referenced_assemblies: list[str] = Field(default_factory=list)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
50
src/mcilspy/utils.py
Normal file
50
src/mcilspy/utils.py
Normal file
@ -0,0 +1,50 @@
|
||||
"""Shared utility functions for mcilspy.
|
||||
|
||||
This module contains common utilities used across the codebase to avoid
|
||||
code duplication and ensure consistent behavior.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def find_ilspycmd_path() -> str | None:
|
||||
"""Find ilspycmd executable in PATH or common install locations.
|
||||
|
||||
This is the single source of truth for locating the ilspycmd binary.
|
||||
It checks:
|
||||
1. Standard PATH (via shutil.which)
|
||||
2. ~/.dotnet/tools (default location for 'dotnet tool install --global')
|
||||
3. Platform-specific locations (Windows %USERPROFILE%)
|
||||
|
||||
Returns:
|
||||
Path to ilspycmd executable if found, None otherwise
|
||||
"""
|
||||
# Check PATH first (handles both ilspycmd and ilspycmd.exe)
|
||||
for cmd_name in ["ilspycmd", "ilspycmd.exe"]:
|
||||
path = shutil.which(cmd_name)
|
||||
if path:
|
||||
return path
|
||||
|
||||
# Check common dotnet tools locations (not always in PATH for MCP servers)
|
||||
home = os.path.expanduser("~")
|
||||
candidates = [
|
||||
os.path.join(home, ".dotnet", "tools", "ilspycmd"),
|
||||
os.path.join(home, ".dotnet", "tools", "ilspycmd.exe"),
|
||||
]
|
||||
|
||||
# Windows-specific: also check USERPROFILE if different from ~
|
||||
if os.name == "nt":
|
||||
userprofile = os.environ.get("USERPROFILE", "")
|
||||
if userprofile and userprofile != home:
|
||||
candidates.append(os.path.join(userprofile, ".dotnet", "tools", "ilspycmd.exe"))
|
||||
|
||||
for candidate in candidates:
|
||||
if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
|
||||
logger.info(f"Found ilspycmd at {candidate} (not in PATH)")
|
||||
return candidate
|
||||
|
||||
return None
|
||||
@ -1,33 +1,58 @@
|
||||
"""Shared pytest fixtures for mcilspy tests."""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Path to test fixtures directory
|
||||
FIXTURES_DIR = Path(__file__).parent / "fixtures"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_assembly_path() -> str:
|
||||
"""Return path to the custom test assembly.
|
||||
|
||||
This is the primary fixture for tests - uses our custom-built
|
||||
TestAssembly.dll with known types and members.
|
||||
"""
|
||||
test_dll = FIXTURES_DIR / "TestAssembly.dll"
|
||||
if not test_dll.exists():
|
||||
pytest.skip("TestAssembly.dll not found - run build_test_assembly.sh first")
|
||||
return str(test_dll)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_assembly_path() -> str:
|
||||
"""Return path to a .NET assembly for testing.
|
||||
|
||||
Uses a known .NET SDK assembly that should exist on systems with dotnet installed.
|
||||
Falls back to SDK assemblies if test assembly not available.
|
||||
Prefer using test_assembly_path for new tests.
|
||||
"""
|
||||
# Try to find a .NET SDK assembly
|
||||
# First try our test assembly
|
||||
test_dll = FIXTURES_DIR / "TestAssembly.dll"
|
||||
if test_dll.exists():
|
||||
return str(test_dll)
|
||||
|
||||
# Fallback: Try to find a .NET SDK assembly
|
||||
dotnet_base = Path("/usr/share/dotnet/sdk")
|
||||
if dotnet_base.exists():
|
||||
# Find any SDK version
|
||||
for sdk_dir in dotnet_base.iterdir():
|
||||
test_dll = sdk_dir / "Sdks" / "Microsoft.NET.Sdk" / "tools" / "net10.0" / "Microsoft.NET.Build.Tasks.dll"
|
||||
if test_dll.exists():
|
||||
return str(test_dll)
|
||||
# Try older paths
|
||||
for net_version in ["net9.0", "net8.0", "net7.0", "net6.0"]:
|
||||
test_dll = sdk_dir / "Sdks" / "Microsoft.NET.Sdk" / "tools" / net_version / "Microsoft.NET.Build.Tasks.dll"
|
||||
for net_version in ["net10.0", "net9.0", "net8.0", "net7.0", "net6.0"]:
|
||||
test_dll = (
|
||||
sdk_dir
|
||||
/ "Sdks"
|
||||
/ "Microsoft.NET.Sdk"
|
||||
/ "tools"
|
||||
/ net_version
|
||||
/ "Microsoft.NET.Build.Tasks.dll"
|
||||
)
|
||||
if test_dll.exists():
|
||||
return str(test_dll)
|
||||
|
||||
# Fallback: any .dll in dotnet directory
|
||||
for root, dirs, files in os.walk("/usr/share/dotnet"):
|
||||
# Last resort: any .dll in dotnet directory
|
||||
for root, _dirs, files in os.walk("/usr/share/dotnet"):
|
||||
for f in files:
|
||||
if f.endswith(".dll"):
|
||||
return os.path.join(root, f)
|
||||
@ -39,3 +64,22 @@ def sample_assembly_path() -> str:
|
||||
def nonexistent_path() -> str:
|
||||
"""Return a path that doesn't exist."""
|
||||
return "/nonexistent/path/to/assembly.dll"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ilspycmd_installed() -> bool:
|
||||
"""Check if ilspycmd is available for integration tests."""
|
||||
return shutil.which("ilspycmd") is not None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def skip_without_ilspycmd(ilspycmd_installed):
|
||||
"""Skip test if ilspycmd is not installed."""
|
||||
if not ilspycmd_installed:
|
||||
pytest.skip("ilspycmd not installed")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_output_dir(tmp_path):
|
||||
"""Provide a temporary directory for test outputs."""
|
||||
return str(tmp_path)
|
||||
|
||||
214
tests/fixtures/TestAssembly.cs
vendored
Normal file
214
tests/fixtures/TestAssembly.cs
vendored
Normal file
@ -0,0 +1,214 @@
|
||||
using System;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace TestNamespace
|
||||
{
|
||||
/// <summary>
|
||||
/// A test class with various members for testing the mcilspy MCP server.
|
||||
/// </summary>
|
||||
public class TestClass
|
||||
{
|
||||
// Constants for testing string search
|
||||
public const string API_KEY = "test-secret-key";
|
||||
public const string BASE_URL = "https://api.example.com";
|
||||
public const int MAX_RETRIES = 3;
|
||||
|
||||
// Fields
|
||||
public static readonly string BaseUrl = "https://api.example.com";
|
||||
private int _privateField;
|
||||
protected string _protectedField;
|
||||
internal double _internalField;
|
||||
|
||||
// Properties
|
||||
public string Name { get; set; }
|
||||
public int Age { get; private set; }
|
||||
public virtual bool IsActive { get; set; }
|
||||
|
||||
// Events
|
||||
public event EventHandler OnChange;
|
||||
public event EventHandler<string> OnMessage;
|
||||
|
||||
// Constructors
|
||||
public TestClass()
|
||||
{
|
||||
Name = "Default";
|
||||
Age = 0;
|
||||
}
|
||||
|
||||
public TestClass(string name, int age)
|
||||
{
|
||||
Name = name;
|
||||
Age = age;
|
||||
}
|
||||
|
||||
// Methods
|
||||
public void DoSomething()
|
||||
{
|
||||
Console.WriteLine("Hello from DoSomething");
|
||||
OnChange?.Invoke(this, EventArgs.Empty);
|
||||
}
|
||||
|
||||
public string GetGreeting()
|
||||
{
|
||||
return $"Hello, {Name}!";
|
||||
}
|
||||
|
||||
public static int Add(int a, int b)
|
||||
{
|
||||
return a + b;
|
||||
}
|
||||
|
||||
protected virtual void OnPropertyChanged(string propertyName)
|
||||
{
|
||||
OnMessage?.Invoke(this, propertyName);
|
||||
}
|
||||
|
||||
private void PrivateMethod()
|
||||
{
|
||||
_privateField = 42;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Interface for testing interface discovery.
|
||||
/// </summary>
|
||||
public interface ITestService
|
||||
{
|
||||
void Execute();
|
||||
Task<string> ExecuteAsync();
|
||||
string ServiceName { get; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Another interface for inheritance testing.
|
||||
/// </summary>
|
||||
public interface IConfigurable
|
||||
{
|
||||
void Configure(string settings);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Struct for testing struct discovery.
|
||||
/// </summary>
|
||||
public struct TestStruct
|
||||
{
|
||||
public int Value;
|
||||
public string Label;
|
||||
|
||||
public TestStruct(int value, string label)
|
||||
{
|
||||
Value = value;
|
||||
Label = label;
|
||||
}
|
||||
|
||||
public override string ToString() => $"{Label}: {Value}";
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Enum for testing enum discovery.
|
||||
/// </summary>
|
||||
public enum TestEnum
|
||||
{
|
||||
None = 0,
|
||||
First = 1,
|
||||
Second = 2,
|
||||
Third = 3
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Delegate for testing delegate discovery.
|
||||
/// </summary>
|
||||
public delegate void TestDelegate(string message);
|
||||
|
||||
/// <summary>
|
||||
/// Delegate with return type.
|
||||
/// </summary>
|
||||
public delegate bool ValidationDelegate<T>(T value);
|
||||
|
||||
/// <summary>
|
||||
/// Service implementation for testing class relationships.
|
||||
/// </summary>
|
||||
public class TestServiceImpl : ITestService, IConfigurable
|
||||
{
|
||||
private string _config;
|
||||
|
||||
public string ServiceName => "TestService";
|
||||
|
||||
public void Execute()
|
||||
{
|
||||
Console.WriteLine($"Executing with config: {_config}");
|
||||
}
|
||||
|
||||
public Task<string> ExecuteAsync()
|
||||
{
|
||||
return Task.FromResult($"Async result from {ServiceName}");
|
||||
}
|
||||
|
||||
public void Configure(string settings)
|
||||
{
|
||||
_config = settings;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Nested class for testing nested type discovery.
|
||||
/// </summary>
|
||||
public class OuterClass
|
||||
{
|
||||
public class NestedClass
|
||||
{
|
||||
public string Value { get; set; }
|
||||
}
|
||||
|
||||
private class PrivateNestedClass
|
||||
{
|
||||
public int Secret { get; set; }
|
||||
}
|
||||
|
||||
public NestedClass CreateNested() => new NestedClass();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Abstract class for testing abstract type discovery.
|
||||
/// </summary>
|
||||
public abstract class AbstractBase
|
||||
{
|
||||
public abstract void AbstractMethod();
|
||||
public virtual void VirtualMethod() { }
|
||||
|
||||
protected string BaseProperty { get; set; }
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Derived class for testing inheritance.
|
||||
/// </summary>
|
||||
public class DerivedClass : AbstractBase
|
||||
{
|
||||
public override void AbstractMethod()
|
||||
{
|
||||
Console.WriteLine("Implemented abstract method");
|
||||
}
|
||||
|
||||
public override void VirtualMethod()
|
||||
{
|
||||
base.VirtualMethod();
|
||||
Console.WriteLine("Overridden virtual method");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace TestNamespace.SubNamespace
|
||||
{
|
||||
/// <summary>
|
||||
/// Class in a sub-namespace for testing namespace filtering.
|
||||
/// </summary>
|
||||
public class SubClass
|
||||
{
|
||||
public const string CONNECTION_STRING = "Server=localhost;Database=test";
|
||||
|
||||
public void SubMethod()
|
||||
{
|
||||
Console.WriteLine("Sub namespace method");
|
||||
}
|
||||
}
|
||||
}
|
||||
BIN
tests/fixtures/TestAssembly.dll
vendored
Normal file
BIN
tests/fixtures/TestAssembly.dll
vendored
Normal file
Binary file not shown.
1
tests/integration/__init__.py
Normal file
1
tests/integration/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""Integration tests that use real ilspycmd calls."""
|
||||
379
tests/integration/test_real_assembly.py
Normal file
379
tests/integration/test_real_assembly.py
Normal file
@ -0,0 +1,379 @@
|
||||
"""Integration tests using the custom TestAssembly.dll fixture.
|
||||
|
||||
These tests exercise the full stack including ilspycmd calls.
|
||||
Tests are skipped if ilspycmd is not installed.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from mcilspy.ilspy_wrapper import ILSpyWrapper
|
||||
from mcilspy.metadata_reader import MetadataReader
|
||||
from mcilspy.models import (
|
||||
DecompileRequest,
|
||||
EntityType,
|
||||
LanguageVersion,
|
||||
ListTypesRequest,
|
||||
)
|
||||
|
||||
|
||||
class TestMetadataReaderWithTestAssembly:
|
||||
"""Test MetadataReader against our custom test assembly."""
|
||||
|
||||
def test_get_assembly_metadata(self, test_assembly_path):
|
||||
"""Test reading metadata from test assembly."""
|
||||
with MetadataReader(test_assembly_path) as reader:
|
||||
meta = reader.get_assembly_metadata()
|
||||
|
||||
assert meta.name == "TestAssemblyProject"
|
||||
assert meta.type_count > 0
|
||||
assert meta.method_count > 0
|
||||
|
||||
def test_list_methods_finds_known_methods(self, test_assembly_path):
|
||||
"""Test that we can find methods we know exist."""
|
||||
with MetadataReader(test_assembly_path) as reader:
|
||||
methods = reader.list_methods()
|
||||
|
||||
method_names = [m.name for m in methods]
|
||||
|
||||
# Check for methods we defined in TestClass
|
||||
assert "DoSomething" in method_names
|
||||
assert "GetGreeting" in method_names
|
||||
assert "Add" in method_names
|
||||
|
||||
def test_list_methods_with_type_filter(self, test_assembly_path):
|
||||
"""Test filtering methods by type."""
|
||||
with MetadataReader(test_assembly_path) as reader:
|
||||
methods = reader.list_methods(type_filter="TestClass")
|
||||
|
||||
# All methods should be from types containing "TestClass"
|
||||
for method in methods:
|
||||
assert "TestClass" in method.declaring_type
|
||||
|
||||
def test_list_methods_with_namespace_filter(self, test_assembly_path):
|
||||
"""Test filtering methods by namespace."""
|
||||
with MetadataReader(test_assembly_path) as reader:
|
||||
methods = reader.list_methods(namespace_filter="SubNamespace")
|
||||
|
||||
# Should only find methods from SubNamespace
|
||||
for method in methods:
|
||||
assert method.namespace is not None
|
||||
assert "SubNamespace" in method.namespace
|
||||
|
||||
def test_list_methods_public_only(self, test_assembly_path):
|
||||
"""Test filtering for public methods only."""
|
||||
with MetadataReader(test_assembly_path) as reader:
|
||||
public_methods = reader.list_methods(public_only=True)
|
||||
all_methods = reader.list_methods(public_only=False)
|
||||
|
||||
# Should have fewer public methods than total
|
||||
assert len(public_methods) <= len(all_methods)
|
||||
# All returned methods should be public
|
||||
for method in public_methods:
|
||||
assert method.is_public
|
||||
|
||||
def test_list_fields_finds_known_fields(self, test_assembly_path):
|
||||
"""Test that we can find fields we defined."""
|
||||
with MetadataReader(test_assembly_path) as reader:
|
||||
fields = reader.list_fields()
|
||||
|
||||
field_names = [f.name for f in fields]
|
||||
|
||||
# Check for constants and fields we defined
|
||||
assert "API_KEY" in field_names
|
||||
assert "BASE_URL" in field_names
|
||||
assert "MAX_RETRIES" in field_names
|
||||
|
||||
def test_list_fields_constants_only(self, test_assembly_path):
|
||||
"""Test filtering for constant fields only."""
|
||||
with MetadataReader(test_assembly_path) as reader:
|
||||
constants = reader.list_fields(constants_only=True)
|
||||
|
||||
# All returned fields should be literals
|
||||
for field in constants:
|
||||
assert field.is_literal
|
||||
|
||||
const_names = [f.name for f in constants]
|
||||
assert "API_KEY" in const_names
|
||||
assert "MAX_RETRIES" in const_names
|
||||
|
||||
def test_list_properties_finds_known_properties(self, test_assembly_path):
|
||||
"""Test that we can find properties we defined."""
|
||||
with MetadataReader(test_assembly_path) as reader:
|
||||
properties = reader.list_properties()
|
||||
|
||||
prop_names = [p.name for p in properties]
|
||||
|
||||
# Check for properties we defined
|
||||
assert "Name" in prop_names
|
||||
assert "Age" in prop_names
|
||||
assert "IsActive" in prop_names
|
||||
assert "ServiceName" in prop_names
|
||||
|
||||
def test_list_events_finds_known_events(self, test_assembly_path):
|
||||
"""Test that we can find events we defined."""
|
||||
with MetadataReader(test_assembly_path) as reader:
|
||||
events = reader.list_events()
|
||||
|
||||
event_names = [e.name for e in events]
|
||||
|
||||
# Check for events we defined
|
||||
assert "OnChange" in event_names
|
||||
assert "OnMessage" in event_names
|
||||
|
||||
def test_list_resources_empty_for_test_assembly(self, test_assembly_path):
|
||||
"""Test that test assembly has no embedded resources."""
|
||||
with MetadataReader(test_assembly_path) as reader:
|
||||
resources = reader.list_resources()
|
||||
|
||||
# Our simple test assembly has no resources
|
||||
assert isinstance(resources, list)
|
||||
|
||||
|
||||
class TestILSpyWrapperWithTestAssembly:
|
||||
"""Integration tests for ILSpyWrapper using real ilspycmd calls."""
|
||||
|
||||
@pytest.fixture
|
||||
def wrapper(self, skip_without_ilspycmd):
|
||||
"""Get wrapper instance, skipping if ilspycmd not available."""
|
||||
return ILSpyWrapper()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_types_finds_classes(self, wrapper, test_assembly_path):
|
||||
"""Test listing classes from test assembly."""
|
||||
request = ListTypesRequest(
|
||||
assembly_path=test_assembly_path,
|
||||
entity_types=[EntityType.CLASS],
|
||||
)
|
||||
response = await wrapper.list_types(request)
|
||||
|
||||
assert response.success
|
||||
assert response.total_count > 0
|
||||
|
||||
type_names = [t.name for t in response.types]
|
||||
assert "TestClass" in type_names
|
||||
assert "TestServiceImpl" in type_names
|
||||
assert "OuterClass" in type_names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_types_finds_interfaces(self, wrapper, test_assembly_path):
|
||||
"""Test listing interfaces from test assembly."""
|
||||
request = ListTypesRequest(
|
||||
assembly_path=test_assembly_path,
|
||||
entity_types=[EntityType.INTERFACE],
|
||||
)
|
||||
response = await wrapper.list_types(request)
|
||||
|
||||
assert response.success
|
||||
|
||||
type_names = [t.name for t in response.types]
|
||||
assert "ITestService" in type_names
|
||||
assert "IConfigurable" in type_names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_types_finds_structs(self, wrapper, test_assembly_path):
|
||||
"""Test listing structs from test assembly."""
|
||||
request = ListTypesRequest(
|
||||
assembly_path=test_assembly_path,
|
||||
entity_types=[EntityType.STRUCT],
|
||||
)
|
||||
response = await wrapper.list_types(request)
|
||||
|
||||
assert response.success
|
||||
|
||||
type_names = [t.name for t in response.types]
|
||||
assert "TestStruct" in type_names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_types_finds_enums(self, wrapper, test_assembly_path):
|
||||
"""Test listing enums from test assembly."""
|
||||
request = ListTypesRequest(
|
||||
assembly_path=test_assembly_path,
|
||||
entity_types=[EntityType.ENUM],
|
||||
)
|
||||
response = await wrapper.list_types(request)
|
||||
|
||||
assert response.success
|
||||
|
||||
type_names = [t.name for t in response.types]
|
||||
assert "TestEnum" in type_names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_types_finds_delegates(self, wrapper, test_assembly_path):
|
||||
"""Test listing delegates from test assembly."""
|
||||
request = ListTypesRequest(
|
||||
assembly_path=test_assembly_path,
|
||||
entity_types=[EntityType.DELEGATE],
|
||||
)
|
||||
response = await wrapper.list_types(request)
|
||||
|
||||
assert response.success
|
||||
|
||||
type_names = [t.name for t in response.types]
|
||||
assert "TestDelegate" in type_names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompile_specific_type(self, wrapper, test_assembly_path):
|
||||
"""Test decompiling a specific type."""
|
||||
request = DecompileRequest(
|
||||
assembly_path=test_assembly_path,
|
||||
type_name="TestNamespace.TestClass",
|
||||
language_version=LanguageVersion.LATEST,
|
||||
)
|
||||
response = await wrapper.decompile(request)
|
||||
|
||||
assert response.success
|
||||
assert response.source_code is not None
|
||||
|
||||
# Check that decompiled code contains expected elements
|
||||
source = response.source_code
|
||||
assert "class TestClass" in source
|
||||
assert "DoSomething" in source
|
||||
assert "GetGreeting" in source
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompile_entire_assembly(self, wrapper, test_assembly_path):
|
||||
"""Test decompiling the entire assembly."""
|
||||
request = DecompileRequest(
|
||||
assembly_path=test_assembly_path,
|
||||
language_version=LanguageVersion.LATEST,
|
||||
)
|
||||
response = await wrapper.decompile(request)
|
||||
|
||||
assert response.success
|
||||
assert response.source_code is not None
|
||||
|
||||
# Check that all types are present
|
||||
source = response.source_code
|
||||
assert "TestClass" in source
|
||||
assert "ITestService" in source
|
||||
assert "TestStruct" in source
|
||||
assert "TestEnum" in source
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompile_to_il(self, wrapper, test_assembly_path):
|
||||
"""Test decompiling to IL code."""
|
||||
request = DecompileRequest(
|
||||
assembly_path=test_assembly_path,
|
||||
type_name="TestNamespace.TestClass",
|
||||
show_il_code=True,
|
||||
)
|
||||
response = await wrapper.decompile(request)
|
||||
|
||||
assert response.success
|
||||
assert response.source_code is not None
|
||||
|
||||
# IL code should contain IL-specific keywords
|
||||
source = response.source_code
|
||||
# IL typically shows .method, .field, etc.
|
||||
assert ".class" in source or "IL_" in source
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompile_to_output_dir(self, wrapper, test_assembly_path, temp_output_dir):
|
||||
"""Test decompiling to an output directory."""
|
||||
request = DecompileRequest(
|
||||
assembly_path=test_assembly_path,
|
||||
output_dir=temp_output_dir,
|
||||
)
|
||||
response = await wrapper.decompile(request)
|
||||
|
||||
assert response.success
|
||||
assert response.output_path is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompile_with_project_structure(
|
||||
self, wrapper, test_assembly_path, temp_output_dir
|
||||
):
|
||||
"""Test decompiling with project structure."""
|
||||
request = DecompileRequest(
|
||||
assembly_path=test_assembly_path,
|
||||
output_dir=temp_output_dir,
|
||||
create_project=True,
|
||||
)
|
||||
response = await wrapper.decompile(request)
|
||||
|
||||
assert response.success
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompile_nonexistent_type(self, wrapper, test_assembly_path):
|
||||
"""Test decompiling a type that doesn't exist."""
|
||||
request = DecompileRequest(
|
||||
assembly_path=test_assembly_path,
|
||||
type_name="NonExistent.FakeClass",
|
||||
)
|
||||
response = await wrapper.decompile(request)
|
||||
|
||||
# Should still succeed but with empty or no matching output
|
||||
# The actual behavior depends on ilspycmd version
|
||||
assert response is not None
|
||||
|
||||
|
||||
class TestIntegrationEndToEnd:
|
||||
"""End-to-end integration tests covering complete workflows."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discover_and_decompile_workflow(
|
||||
self, skip_without_ilspycmd, test_assembly_path
|
||||
):
|
||||
"""Test the typical workflow: list types, then decompile specific one."""
|
||||
wrapper = ILSpyWrapper()
|
||||
|
||||
# Step 1: List all types
|
||||
list_request = ListTypesRequest(
|
||||
assembly_path=test_assembly_path,
|
||||
entity_types=[EntityType.CLASS],
|
||||
)
|
||||
list_response = await wrapper.list_types(list_request)
|
||||
|
||||
assert list_response.success
|
||||
assert len(list_response.types) > 0
|
||||
|
||||
# Step 2: Find TestServiceImpl
|
||||
service_type = None
|
||||
for t in list_response.types:
|
||||
if t.name == "TestServiceImpl":
|
||||
service_type = t
|
||||
break
|
||||
|
||||
assert service_type is not None
|
||||
|
||||
# Step 3: Decompile it
|
||||
decompile_request = DecompileRequest(
|
||||
assembly_path=test_assembly_path,
|
||||
type_name=service_type.full_name,
|
||||
)
|
||||
decompile_response = await wrapper.decompile(decompile_request)
|
||||
|
||||
assert decompile_response.success
|
||||
assert decompile_response.source_code is not None
|
||||
assert "TestServiceImpl" in decompile_response.source_code
|
||||
assert "ITestService" in decompile_response.source_code
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_metadata_and_decompile_combined(
|
||||
self, skip_without_ilspycmd, test_assembly_path
|
||||
):
|
||||
"""Test using metadata reader and ILSpy wrapper together."""
|
||||
# Use metadata reader for quick discovery
|
||||
with MetadataReader(test_assembly_path) as reader:
|
||||
methods = reader.list_methods(type_filter="TestClass")
|
||||
add_method = None
|
||||
for m in methods:
|
||||
if m.name == "Add":
|
||||
add_method = m
|
||||
break
|
||||
|
||||
assert add_method is not None
|
||||
assert add_method.is_static
|
||||
|
||||
# Use ILSpy for decompilation
|
||||
wrapper = ILSpyWrapper()
|
||||
request = DecompileRequest(
|
||||
assembly_path=test_assembly_path,
|
||||
type_name="TestNamespace.TestClass",
|
||||
)
|
||||
response = await wrapper.decompile(request)
|
||||
|
||||
assert response.success
|
||||
# Verify the static method is in the output
|
||||
assert "static" in response.source_code
|
||||
assert "Add" in response.source_code
|
||||
285
tests/test_concurrency.py
Normal file
285
tests/test_concurrency.py
Normal file
@ -0,0 +1,285 @@
|
||||
"""Tests for concurrent tool invocations.
|
||||
|
||||
These tests verify that the server handles multiple simultaneous
|
||||
tool calls correctly using asyncio.gather().
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from mcilspy import server
|
||||
from mcilspy.metadata_reader import MetadataReader
|
||||
|
||||
|
||||
class TestConcurrentMetadataOperations:
|
||||
"""Test concurrent metadata reading operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_search_methods(self, test_assembly_path):
|
||||
"""Test multiple search_methods calls running concurrently."""
|
||||
patterns = ["Get", "Do", "Set", "Add", "Create"]
|
||||
|
||||
tasks = [
|
||||
server.search_methods(test_assembly_path, pattern=p) for p in patterns
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# All tasks should complete successfully
|
||||
assert len(results) == len(patterns)
|
||||
|
||||
# Each result should be a string
|
||||
for result in results:
|
||||
assert isinstance(result, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_search_fields(self, test_assembly_path):
|
||||
"""Test multiple search_fields calls running concurrently."""
|
||||
patterns = ["API", "URL", "MAX", "BASE", "VALUE"]
|
||||
|
||||
tasks = [
|
||||
server.search_fields(test_assembly_path, pattern=p) for p in patterns
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
assert len(results) == len(patterns)
|
||||
for result in results:
|
||||
assert isinstance(result, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_search_properties(self, test_assembly_path):
|
||||
"""Test multiple search_properties calls running concurrently."""
|
||||
patterns = ["Name", "Value", "Is", "Service"]
|
||||
|
||||
tasks = [
|
||||
server.search_properties(test_assembly_path, pattern=p) for p in patterns
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
assert len(results) == len(patterns)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_mixed_operations(self, test_assembly_path):
|
||||
"""Test different metadata operations running concurrently."""
|
||||
tasks = [
|
||||
server.search_methods(test_assembly_path, pattern="Get"),
|
||||
server.search_fields(test_assembly_path, pattern="API"),
|
||||
server.search_properties(test_assembly_path, pattern="Name"),
|
||||
server.list_events(test_assembly_path),
|
||||
server.list_resources(test_assembly_path),
|
||||
server.get_metadata_summary(test_assembly_path),
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
assert len(results) == 6
|
||||
for result in results:
|
||||
assert isinstance(result, str)
|
||||
# None of them should have crashed
|
||||
assert "Traceback" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_same_assembly_multiple_readers(self, test_assembly_path):
|
||||
"""Test multiple MetadataReaders on the same assembly."""
|
||||
async def read_metadata(path):
|
||||
"""Async wrapper for metadata reading."""
|
||||
with MetadataReader(path) as reader:
|
||||
return reader.get_assembly_metadata()
|
||||
|
||||
# Run multiple readers concurrently
|
||||
loop = asyncio.get_event_loop()
|
||||
tasks = [
|
||||
loop.run_in_executor(None, lambda: MetadataReader(test_assembly_path).__enter__().get_assembly_metadata())
|
||||
for _ in range(5)
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
assert len(results) == 5
|
||||
# All results should have the same assembly name
|
||||
names = [r.name for r in results]
|
||||
assert all(n == names[0] for n in names)
|
||||
|
||||
|
||||
class TestConcurrentToolCalls:
|
||||
"""Test concurrent MCP tool invocations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_high_concurrency_search(self, test_assembly_path):
|
||||
"""Test high number of concurrent searches."""
|
||||
num_concurrent = 20
|
||||
|
||||
tasks = [
|
||||
server.search_methods(test_assembly_path, pattern=f"pattern{i}")
|
||||
for i in range(num_concurrent)
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
assert len(results) == num_concurrent
|
||||
# Most should return "No methods found" but shouldn't crash
|
||||
for result in results:
|
||||
assert isinstance(result, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_with_errors(self, test_assembly_path, nonexistent_path):
|
||||
"""Test concurrent calls where some will fail."""
|
||||
tasks = [
|
||||
# These should succeed
|
||||
server.search_methods(test_assembly_path, pattern="Get"),
|
||||
server.search_fields(test_assembly_path, pattern="API"),
|
||||
# These should fail gracefully
|
||||
server.search_methods(nonexistent_path, pattern="test"),
|
||||
server.search_fields(nonexistent_path, pattern="test"),
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
assert len(results) == 4
|
||||
|
||||
# First two should be successful results
|
||||
assert "GetGreeting" in results[0] or "No methods" in results[0]
|
||||
assert "API_KEY" in results[1] or "No fields" in results[1]
|
||||
|
||||
# Last two should have error messages
|
||||
assert "Error" in results[2]
|
||||
assert "Error" in results[3]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_list_operations(self, test_assembly_path):
|
||||
"""Test concurrent list operations."""
|
||||
tasks = [
|
||||
server.list_events(test_assembly_path),
|
||||
server.list_events(test_assembly_path),
|
||||
server.list_resources(test_assembly_path),
|
||||
server.list_resources(test_assembly_path),
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
assert len(results) == 4
|
||||
# Event results should be identical
|
||||
assert results[0] == results[1]
|
||||
# Resource results should be identical
|
||||
assert results[2] == results[3]
|
||||
|
||||
|
||||
class TestConcurrentWithRegex:
|
||||
"""Test concurrent operations with regex patterns."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_regex_searches(self, test_assembly_path):
|
||||
"""Test concurrent regex pattern searches."""
|
||||
patterns = [r"^Get.*", r".*Service$", r"On\w+", r".*Base.*"]
|
||||
|
||||
tasks = [
|
||||
server.search_methods(test_assembly_path, pattern=p, use_regex=True)
|
||||
for p in patterns
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
assert len(results) == 4
|
||||
for result in results:
|
||||
assert isinstance(result, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_invalid_regex(self, test_assembly_path):
|
||||
"""Test that concurrent invalid regex patterns are handled safely."""
|
||||
patterns = [
|
||||
r"[invalid", # Invalid
|
||||
r"valid.*", # Valid
|
||||
r"(?P<broken", # Invalid
|
||||
r"also.*valid", # Valid
|
||||
]
|
||||
|
||||
tasks = [
|
||||
server.search_methods(test_assembly_path, pattern=p, use_regex=True)
|
||||
for p in patterns
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
assert len(results) == 4
|
||||
# Invalid patterns should return error messages
|
||||
assert "Invalid regex" in results[0]
|
||||
assert "Invalid regex" in results[2]
|
||||
# Valid patterns should return results (even if empty)
|
||||
assert "Invalid regex" not in results[1]
|
||||
assert "Invalid regex" not in results[3]
|
||||
|
||||
|
||||
class TestConcurrentNamespaceFiltering:
|
||||
"""Test concurrent operations with namespace filtering."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_namespace_filtered_searches(self, test_assembly_path):
|
||||
"""Test concurrent searches with different namespace filters."""
|
||||
namespace_filters = [
|
||||
"TestNamespace",
|
||||
"SubNamespace",
|
||||
"NonExistent",
|
||||
None, # No filter
|
||||
]
|
||||
|
||||
tasks = [
|
||||
server.search_methods(
|
||||
test_assembly_path, pattern="", namespace_filter=ns
|
||||
)
|
||||
for ns in namespace_filters
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
assert len(results) == 4
|
||||
|
||||
|
||||
class TestConcurrencyIsolation:
|
||||
"""Test that concurrent operations don't interfere with each other."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_results_not_mixed(self, test_assembly_path):
|
||||
"""Verify that concurrent operations return their own results."""
|
||||
# Use very specific patterns that should match different things
|
||||
async def search_and_verify(pattern, expected):
|
||||
result = await server.search_methods(test_assembly_path, pattern=pattern)
|
||||
if expected:
|
||||
assert expected in result, f"Expected '{expected}' in result for pattern '{pattern}'"
|
||||
return result
|
||||
|
||||
tasks = [
|
||||
search_and_verify("DoSomething", "DoSomething"),
|
||||
search_and_verify("GetGreeting", "GetGreeting"),
|
||||
search_and_verify("Add", "Add"),
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# Verify each result contains only its expected match
|
||||
assert "DoSomething" in results[0]
|
||||
assert "GetGreeting" in results[1]
|
||||
assert "Add" in results[2]
|
||||
|
||||
# Results shouldn't be mixed up
|
||||
assert results[0] != results[1]
|
||||
assert results[1] != results[2]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_metadata_summary_consistent(self, test_assembly_path):
|
||||
"""Verify concurrent metadata summary calls return consistent results."""
|
||||
num_concurrent = 10
|
||||
|
||||
tasks = [
|
||||
server.get_metadata_summary(test_assembly_path)
|
||||
for _ in range(num_concurrent)
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# All results should be identical
|
||||
first_result = results[0]
|
||||
for result in results[1:]:
|
||||
assert result == first_result, "Concurrent metadata calls returned different results"
|
||||
302
tests/test_docstrings.py
Normal file
302
tests/test_docstrings.py
Normal file
@ -0,0 +1,302 @@
|
||||
"""Tests for docstring coverage.
|
||||
|
||||
Verifies that all public functions and classes have docstrings.
|
||||
Uses AST to introspect the source code.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
|
||||
import pytest
|
||||
|
||||
import mcilspy.ilspy_wrapper as wrapper_module
|
||||
import mcilspy.metadata_reader as reader_module
|
||||
import mcilspy.models as models_module
|
||||
import mcilspy.utils as utils_module
|
||||
|
||||
# Import the modules we want to check
|
||||
import mcilspy.server as server_module
|
||||
|
||||
|
||||
def get_public_functions_and_classes(module):
|
||||
"""Get all public functions and classes from a module.
|
||||
|
||||
Returns a list of (name, obj, has_docstring) tuples.
|
||||
"""
|
||||
results = []
|
||||
|
||||
for name in dir(module):
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
|
||||
obj = getattr(module, name)
|
||||
|
||||
# Check if it's a function or class defined in this module
|
||||
if not (inspect.isfunction(obj) or inspect.isclass(obj)):
|
||||
continue
|
||||
|
||||
# Skip imported items
|
||||
if hasattr(obj, "__module__") and obj.__module__ != module.__name__:
|
||||
continue
|
||||
|
||||
has_docstring = bool(inspect.getdoc(obj))
|
||||
results.append((name, obj, has_docstring))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def get_public_methods(cls):
|
||||
"""Get all public methods from a class.
|
||||
|
||||
Returns a list of (name, method, has_docstring) tuples.
|
||||
"""
|
||||
results = []
|
||||
|
||||
for name, method in inspect.getmembers(cls, predicate=inspect.isfunction):
|
||||
if name.startswith("_") and not name.startswith("__"):
|
||||
continue
|
||||
if name.startswith("__") and name != "__init__":
|
||||
continue
|
||||
|
||||
has_docstring = bool(inspect.getdoc(method))
|
||||
results.append((name, method, has_docstring))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class TestServerModuleDocstrings:
|
||||
"""Tests for server.py docstring coverage."""
|
||||
|
||||
def test_all_mcp_tools_have_docstrings(self):
|
||||
"""Verify all @mcp.tool() decorated functions have docstrings."""
|
||||
# Find all functions decorated with @mcp.tool()
|
||||
# These are the public API and MUST have docstrings
|
||||
tools = [
|
||||
server_module.check_ilspy_installation,
|
||||
server_module.install_ilspy,
|
||||
server_module.decompile_assembly,
|
||||
server_module.list_types,
|
||||
server_module.generate_diagrammer,
|
||||
server_module.get_assembly_info,
|
||||
server_module.search_types,
|
||||
server_module.search_strings,
|
||||
server_module.search_methods,
|
||||
server_module.search_fields,
|
||||
server_module.search_properties,
|
||||
server_module.list_events,
|
||||
server_module.list_resources,
|
||||
server_module.get_metadata_summary,
|
||||
]
|
||||
|
||||
missing_docstrings = []
|
||||
for tool in tools:
|
||||
docstring = inspect.getdoc(tool)
|
||||
if not docstring:
|
||||
missing_docstrings.append(tool.__name__)
|
||||
|
||||
assert not missing_docstrings, f"Tools missing docstrings: {missing_docstrings}"
|
||||
|
||||
def test_tool_docstrings_have_args_section(self):
|
||||
"""Verify tool docstrings document their arguments."""
|
||||
# Tools with parameters should have Args: section
|
||||
tools_with_params = [
|
||||
server_module.decompile_assembly,
|
||||
server_module.list_types,
|
||||
server_module.generate_diagrammer,
|
||||
server_module.get_assembly_info,
|
||||
server_module.search_types,
|
||||
server_module.search_strings,
|
||||
server_module.search_methods,
|
||||
server_module.search_fields,
|
||||
server_module.search_properties,
|
||||
server_module.list_events,
|
||||
server_module.list_resources,
|
||||
server_module.get_metadata_summary,
|
||||
]
|
||||
|
||||
missing_args = []
|
||||
for tool in tools_with_params:
|
||||
docstring = inspect.getdoc(tool)
|
||||
sig = inspect.signature(tool)
|
||||
|
||||
# Get non-ctx parameters
|
||||
params = [
|
||||
p
|
||||
for p in sig.parameters.values()
|
||||
if p.name != "ctx" and p.name != "self"
|
||||
]
|
||||
|
||||
if params and docstring and "Args:" not in docstring:
|
||||
missing_args.append(tool.__name__)
|
||||
|
||||
assert not missing_args, f"Tools missing Args section: {missing_args}"
|
||||
|
||||
def test_helper_functions_have_docstrings(self):
|
||||
"""Verify helper functions have docstrings."""
|
||||
helpers = [
|
||||
server_module.get_wrapper,
|
||||
server_module._format_error,
|
||||
utils_module.find_ilspycmd_path, # Moved to utils
|
||||
server_module._check_dotnet_tools,
|
||||
server_module._detect_platform,
|
||||
server_module._try_install_dotnet_sdk,
|
||||
]
|
||||
|
||||
missing_docstrings = []
|
||||
for helper in helpers:
|
||||
docstring = inspect.getdoc(helper)
|
||||
if not docstring:
|
||||
missing_docstrings.append(helper.__name__)
|
||||
|
||||
assert not missing_docstrings, f"Helpers missing docstrings: {missing_docstrings}"
|
||||
|
||||
|
||||
class TestWrapperModuleDocstrings:
|
||||
"""Tests for ilspy_wrapper.py docstring coverage."""
|
||||
|
||||
def test_wrapper_class_has_docstring(self):
|
||||
"""Verify ILSpyWrapper class has a docstring."""
|
||||
docstring = inspect.getdoc(wrapper_module.ILSpyWrapper)
|
||||
assert docstring, "ILSpyWrapper class should have a docstring"
|
||||
|
||||
def test_wrapper_public_methods_have_docstrings(self):
|
||||
"""Verify ILSpyWrapper public methods have docstrings."""
|
||||
methods_to_check = [
|
||||
"decompile",
|
||||
"list_types",
|
||||
"generate_diagrammer",
|
||||
"get_assembly_info",
|
||||
]
|
||||
|
||||
missing_docstrings = []
|
||||
for method_name in methods_to_check:
|
||||
method = getattr(wrapper_module.ILSpyWrapper, method_name, None)
|
||||
if method:
|
||||
docstring = inspect.getdoc(method)
|
||||
if not docstring:
|
||||
missing_docstrings.append(method_name)
|
||||
|
||||
assert not missing_docstrings, (
|
||||
f"ILSpyWrapper methods missing docstrings: {missing_docstrings}"
|
||||
)
|
||||
|
||||
|
||||
class TestMetadataReaderDocstrings:
|
||||
"""Tests for metadata_reader.py docstring coverage."""
|
||||
|
||||
def test_reader_class_has_docstring(self):
|
||||
"""Verify MetadataReader class has a docstring."""
|
||||
docstring = inspect.getdoc(reader_module.MetadataReader)
|
||||
assert docstring, "MetadataReader class should have a docstring"
|
||||
|
||||
def test_reader_public_methods_have_docstrings(self):
|
||||
"""Verify MetadataReader public methods have docstrings."""
|
||||
methods_to_check = [
|
||||
"get_assembly_metadata",
|
||||
"list_methods",
|
||||
"list_fields",
|
||||
"list_properties",
|
||||
"list_events",
|
||||
"list_resources",
|
||||
]
|
||||
|
||||
missing_docstrings = []
|
||||
for method_name in methods_to_check:
|
||||
method = getattr(reader_module.MetadataReader, method_name, None)
|
||||
if method:
|
||||
docstring = inspect.getdoc(method)
|
||||
if not docstring:
|
||||
missing_docstrings.append(method_name)
|
||||
|
||||
assert not missing_docstrings, (
|
||||
f"MetadataReader methods missing docstrings: {missing_docstrings}"
|
||||
)
|
||||
|
||||
|
||||
class TestModelsDocstrings:
|
||||
"""Tests for models.py docstring coverage."""
|
||||
|
||||
def test_pydantic_models_have_docstrings(self):
|
||||
"""Verify Pydantic model classes have docstrings."""
|
||||
models_to_check = [
|
||||
models_module.DecompileRequest,
|
||||
models_module.DecompileResponse,
|
||||
models_module.ListTypesRequest,
|
||||
models_module.ListTypesResponse,
|
||||
models_module.TypeInfo,
|
||||
models_module.AssemblyInfo,
|
||||
]
|
||||
|
||||
missing_docstrings = []
|
||||
for model in models_to_check:
|
||||
docstring = inspect.getdoc(model)
|
||||
if not docstring:
|
||||
missing_docstrings.append(model.__name__)
|
||||
|
||||
# Just check that most have docstrings - Pydantic models are self-documenting
|
||||
# through their field names
|
||||
assert len(missing_docstrings) <= 2, (
|
||||
f"Too many models missing docstrings: {missing_docstrings}"
|
||||
)
|
||||
|
||||
|
||||
class TestModuleDocstrings:
|
||||
"""Tests for module-level docstrings."""
|
||||
|
||||
def test_all_modules_have_docstrings(self):
|
||||
"""Verify all mcilspy modules have module-level docstrings."""
|
||||
modules = [
|
||||
server_module,
|
||||
wrapper_module,
|
||||
reader_module,
|
||||
models_module,
|
||||
]
|
||||
|
||||
missing_docstrings = []
|
||||
for module in modules:
|
||||
if not module.__doc__:
|
||||
missing_docstrings.append(module.__name__)
|
||||
|
||||
# Just warn, don't fail - module docstrings are nice but not critical
|
||||
if missing_docstrings:
|
||||
pytest.skip(f"Modules missing docstrings (non-critical): {missing_docstrings}")
|
||||
|
||||
|
||||
class TestDocstringQuality:
|
||||
"""Tests for docstring quality (not just presence)."""
|
||||
|
||||
def test_tool_docstrings_not_empty(self):
|
||||
"""Verify tool docstrings have meaningful content."""
|
||||
tools = [
|
||||
server_module.decompile_assembly,
|
||||
server_module.list_types,
|
||||
server_module.search_methods,
|
||||
]
|
||||
|
||||
short_docstrings = []
|
||||
for tool in tools:
|
||||
docstring = inspect.getdoc(tool)
|
||||
if docstring and len(docstring) < 50:
|
||||
short_docstrings.append(f"{tool.__name__}: {len(docstring)} chars")
|
||||
|
||||
assert not short_docstrings, (
|
||||
f"Tools have too-short docstrings: {short_docstrings}"
|
||||
)
|
||||
|
||||
def test_docstrings_describe_purpose(self):
|
||||
"""Verify key tool docstrings describe what the tool does."""
|
||||
key_words = {
|
||||
server_module.decompile_assembly: ["decompile", "assembly", "C#"],
|
||||
server_module.list_types: ["types", "list", "class"],
|
||||
server_module.search_methods: ["search", "method"],
|
||||
}
|
||||
|
||||
missing_keywords = []
|
||||
for tool, keywords in key_words.items():
|
||||
docstring = inspect.getdoc(tool).lower() if inspect.getdoc(tool) else ""
|
||||
for keyword in keywords:
|
||||
if keyword.lower() not in docstring:
|
||||
missing_keywords.append(f"{tool.__name__} missing '{keyword}'")
|
||||
|
||||
assert not missing_keywords, (
|
||||
f"Docstrings missing expected keywords: {missing_keywords}"
|
||||
)
|
||||
425
tests/test_error_paths.py
Normal file
425
tests/test_error_paths.py
Normal file
@ -0,0 +1,425 @@
|
||||
"""Tests for error handling paths.
|
||||
|
||||
These tests verify that the server handles various error conditions gracefully:
|
||||
- Invalid regex patterns
|
||||
- ilspycmd not found scenarios
|
||||
- Invalid language versions
|
||||
- File not found errors
|
||||
- Invalid assembly files
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from mcilspy import server
|
||||
from mcilspy.ilspy_wrapper import ILSpyWrapper
|
||||
from mcilspy.metadata_reader import MetadataReader
|
||||
from mcilspy.models import EntityType
|
||||
|
||||
|
||||
# Fixture to bypass path validation for tests using mock paths
|
||||
@pytest.fixture
|
||||
def bypass_path_validation():
|
||||
"""Bypass _validate_assembly_path for tests using mock wrapper."""
|
||||
def passthrough(path):
|
||||
return path
|
||||
with patch.object(server, "_validate_assembly_path", side_effect=passthrough):
|
||||
yield
|
||||
|
||||
|
||||
class TestInvalidRegexPatterns:
|
||||
"""Tests for invalid regex pattern handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_types_invalid_regex(self, test_assembly_path):
|
||||
"""Test search_types with invalid regex pattern."""
|
||||
# Use an invalid regex pattern
|
||||
result = await server.search_types(
|
||||
test_assembly_path,
|
||||
pattern="[invalid(regex",
|
||||
use_regex=True,
|
||||
)
|
||||
|
||||
assert "Invalid regex pattern" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_methods_invalid_regex(self, test_assembly_path):
|
||||
"""Test search_methods with invalid regex pattern."""
|
||||
result = await server.search_methods(
|
||||
test_assembly_path,
|
||||
pattern="[unclosed",
|
||||
use_regex=True,
|
||||
)
|
||||
|
||||
assert "Invalid regex pattern" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_fields_invalid_regex(self, test_assembly_path):
|
||||
"""Test search_fields with invalid regex pattern."""
|
||||
result = await server.search_fields(
|
||||
test_assembly_path,
|
||||
pattern="*invalid*",
|
||||
use_regex=True,
|
||||
)
|
||||
|
||||
assert "Invalid regex pattern" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_properties_invalid_regex(self, test_assembly_path):
|
||||
"""Test search_properties with invalid regex pattern."""
|
||||
result = await server.search_properties(
|
||||
test_assembly_path,
|
||||
pattern="(?P<broken",
|
||||
use_regex=True,
|
||||
)
|
||||
|
||||
assert "Invalid regex pattern" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_strings_invalid_regex(self, test_assembly_path):
|
||||
"""Test search_strings with invalid regex pattern."""
|
||||
# Now uses fast MetadataReader search - no wrapper needed
|
||||
result = await server.search_strings(
|
||||
test_assembly_path,
|
||||
pattern="[broken",
|
||||
use_regex=True,
|
||||
)
|
||||
|
||||
assert "Invalid regex pattern" in result
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("bypass_path_validation")
|
||||
class TestIlspyCmdNotFound:
|
||||
"""Tests for scenarios where ilspycmd is not installed."""
|
||||
|
||||
def test_wrapper_init_raises_when_not_found(self):
|
||||
"""Test that ILSpyWrapper raises RuntimeError when ilspycmd not found."""
|
||||
with (
|
||||
patch("shutil.which", return_value=None),
|
||||
patch("os.path.isfile", return_value=False),
|
||||
pytest.raises(RuntimeError) as exc_info,
|
||||
):
|
||||
ILSpyWrapper()
|
||||
|
||||
assert "ILSpyCmd not found" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompile_when_not_installed(self):
|
||||
"""Test decompile_assembly when ilspycmd is not installed."""
|
||||
with patch.object(
|
||||
server, "get_wrapper", side_effect=RuntimeError("ILSpyCmd not found")
|
||||
):
|
||||
result = await server.decompile_assembly("/path/to/test.dll")
|
||||
|
||||
assert "Error" in result
|
||||
assert "ILSpyCmd not found" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_types_when_not_installed(self):
|
||||
"""Test list_types when ilspycmd is not installed."""
|
||||
with patch.object(
|
||||
server, "get_wrapper", side_effect=RuntimeError("ILSpyCmd not found")
|
||||
):
|
||||
result = await server.list_types("/path/to/test.dll")
|
||||
|
||||
assert "Error" in result
|
||||
assert "ILSpyCmd not found" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_diagrammer_when_not_installed(self):
|
||||
"""Test generate_diagrammer when ilspycmd is not installed."""
|
||||
with patch.object(
|
||||
server, "get_wrapper", side_effect=RuntimeError("ILSpyCmd not found")
|
||||
):
|
||||
result = await server.generate_diagrammer("/path/to/test.dll")
|
||||
|
||||
assert "Error" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_assembly_info_when_not_installed(self):
|
||||
"""Test get_assembly_info when ilspycmd is not installed."""
|
||||
with patch.object(
|
||||
server, "get_wrapper", side_effect=RuntimeError("ILSpyCmd not found")
|
||||
):
|
||||
result = await server.get_assembly_info("/path/to/test.dll")
|
||||
|
||||
assert "Error" in result
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("bypass_path_validation")
|
||||
class TestInvalidLanguageVersion:
|
||||
"""Tests for invalid language version handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompile_with_invalid_language_version(self):
|
||||
"""Test decompile_assembly with invalid language version."""
|
||||
# The LanguageVersion enum should raise ValueError for invalid versions
|
||||
result = await server.decompile_assembly(
|
||||
"/path/to/test.dll",
|
||||
language_version="CSharp99", # Invalid version
|
||||
)
|
||||
|
||||
# Should return an error about the invalid language version
|
||||
assert "Invalid language version" in result or "Error" in result
|
||||
|
||||
|
||||
class TestFileNotFoundErrors:
|
||||
"""Tests for file not found error handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_methods_file_not_found(self, nonexistent_path):
|
||||
"""Test search_methods with nonexistent file."""
|
||||
result = await server.search_methods(nonexistent_path, pattern="test")
|
||||
|
||||
assert "Error" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_fields_file_not_found(self, nonexistent_path):
|
||||
"""Test search_fields with nonexistent file."""
|
||||
result = await server.search_fields(nonexistent_path, pattern="test")
|
||||
|
||||
assert "Error" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_properties_file_not_found(self, nonexistent_path):
|
||||
"""Test search_properties with nonexistent file."""
|
||||
result = await server.search_properties(nonexistent_path, pattern="test")
|
||||
|
||||
assert "Error" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_events_file_not_found(self, nonexistent_path):
|
||||
"""Test list_events with nonexistent file."""
|
||||
result = await server.list_events(nonexistent_path)
|
||||
|
||||
assert "Error" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_resources_file_not_found(self, nonexistent_path):
|
||||
"""Test list_resources with nonexistent file."""
|
||||
result = await server.list_resources(nonexistent_path)
|
||||
|
||||
assert "Error" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_metadata_summary_file_not_found(self, nonexistent_path):
|
||||
"""Test get_metadata_summary with nonexistent file."""
|
||||
result = await server.get_metadata_summary(nonexistent_path)
|
||||
|
||||
assert "Error" in result
|
||||
|
||||
def test_metadata_reader_file_not_found(self, nonexistent_path):
|
||||
"""Test MetadataReader with nonexistent file."""
|
||||
with pytest.raises(FileNotFoundError):
|
||||
MetadataReader(nonexistent_path)
|
||||
|
||||
|
||||
class TestInvalidAssemblyFiles:
|
||||
"""Tests for handling invalid assembly files."""
|
||||
|
||||
@pytest.fixture
|
||||
def invalid_assembly_path(self, tmp_path):
|
||||
"""Create a file that is not a valid .NET assembly."""
|
||||
invalid_file = tmp_path / "invalid.dll"
|
||||
invalid_file.write_text("This is not a valid PE file")
|
||||
return str(invalid_file)
|
||||
|
||||
def test_metadata_reader_invalid_assembly(self, invalid_assembly_path):
|
||||
"""Test MetadataReader with an invalid assembly file."""
|
||||
# dnfile may silently fail or raise on invalid assemblies
|
||||
# Either outcome is acceptable - the key is it doesn't crash
|
||||
try:
|
||||
with MetadataReader(invalid_assembly_path) as reader:
|
||||
# If it opens, trying to read should fail or return empty
|
||||
reader.get_assembly_metadata()
|
||||
# If we get here, that's OK - just shouldn't crash
|
||||
assert True
|
||||
except Exception:
|
||||
# An exception is also acceptable for invalid PE files
|
||||
assert True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_methods_invalid_assembly(self, invalid_assembly_path):
|
||||
"""Test search_methods with invalid assembly."""
|
||||
result = await server.search_methods(invalid_assembly_path, pattern="test")
|
||||
|
||||
assert "Error" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_metadata_summary_invalid_assembly(self, invalid_assembly_path):
|
||||
"""Test get_metadata_summary with invalid assembly."""
|
||||
result = await server.get_metadata_summary(invalid_assembly_path)
|
||||
|
||||
assert "Error" in result
|
||||
|
||||
|
||||
class TestEntityTypeValidation:
|
||||
"""Tests for EntityType enum validation."""
|
||||
|
||||
def test_invalid_entity_type_string(self):
|
||||
"""Test EntityType.from_string with invalid type name."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
EntityType.from_string("invalid_type")
|
||||
|
||||
assert "Invalid entity type" in str(exc_info.value)
|
||||
|
||||
def test_invalid_entity_type_single_letter(self):
|
||||
"""Test EntityType.from_string with invalid single letter."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
EntityType.from_string("x")
|
||||
|
||||
assert "Invalid entity type" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_types_with_invalid_entity_type(self):
|
||||
"""Test list_types with invalid entity type in list."""
|
||||
# The server should skip invalid entity types with a warning
|
||||
from mcilspy.models import ListTypesResponse
|
||||
|
||||
mock_response = ListTypesResponse(success=True, types=[], total_count=0)
|
||||
mock_wrapper = MagicMock()
|
||||
mock_wrapper.list_types = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(server, "get_wrapper", return_value=mock_wrapper):
|
||||
# Should not raise, but should skip invalid types
|
||||
result = await server.list_types(
|
||||
"/path/to/test.dll",
|
||||
entity_types=["class", "invalid", "interface"],
|
||||
)
|
||||
|
||||
# Should still work, just skipping the invalid type
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("bypass_path_validation")
|
||||
class TestContextInfoFailure:
|
||||
"""Tests for handling ctx.info() failures."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompile_with_failing_context(self):
|
||||
"""Test decompile_assembly when ctx.info() fails."""
|
||||
from mcilspy.models import DecompileResponse
|
||||
|
||||
mock_response = DecompileResponse(
|
||||
success=True,
|
||||
assembly_name="Test",
|
||||
source_code="class Test { }",
|
||||
)
|
||||
mock_wrapper = MagicMock()
|
||||
mock_wrapper.decompile = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Create a mock context that fails on info()
|
||||
mock_ctx = MagicMock()
|
||||
mock_ctx.info = AsyncMock(side_effect=Exception("Context info failed"))
|
||||
|
||||
with patch.object(server, "get_wrapper", return_value=mock_wrapper):
|
||||
# The function should handle the failing ctx.info gracefully
|
||||
# Note: Currently ctx is optional and None by default
|
||||
result = await server.decompile_assembly("/path/to/test.dll", ctx=None)
|
||||
|
||||
# Should still succeed since ctx is optional
|
||||
assert "Test" in result
|
||||
|
||||
|
||||
class TestEmptyResults:
|
||||
"""Tests for handling empty result sets."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_methods_empty_assembly(self, test_assembly_path):
|
||||
"""Test search with pattern that matches nothing."""
|
||||
result = await server.search_methods(
|
||||
test_assembly_path, pattern="ZZZZNONEXISTENT"
|
||||
)
|
||||
|
||||
assert "No methods found" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_fields_no_matches(self, test_assembly_path):
|
||||
"""Test field search with no matches."""
|
||||
result = await server.search_fields(
|
||||
test_assembly_path, pattern="NONEXISTENT_FIELD_12345"
|
||||
)
|
||||
|
||||
assert "No fields found" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_properties_no_matches(self, test_assembly_path):
|
||||
"""Test property search with no matches."""
|
||||
result = await server.search_properties(
|
||||
test_assembly_path, pattern="NONEXISTENT_PROPERTY"
|
||||
)
|
||||
|
||||
assert "No properties found" in result
|
||||
|
||||
|
||||
class TestInstallIlspy:
|
||||
"""Tests for install_ilspy tool error paths."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_install_without_dotnet(self):
|
||||
"""Test install_ilspy when dotnet is not available."""
|
||||
mock_status = {
|
||||
"dotnet_available": False,
|
||||
"dotnet_version": None,
|
||||
"ilspycmd_available": False,
|
||||
"ilspycmd_version": None,
|
||||
"ilspycmd_path": None,
|
||||
}
|
||||
|
||||
with (
|
||||
patch.object(server, "_check_dotnet_tools", return_value=mock_status),
|
||||
patch.object(
|
||||
server,
|
||||
"_detect_platform",
|
||||
return_value={
|
||||
"system": "linux",
|
||||
"distro": "arch",
|
||||
"package_manager": "pacman",
|
||||
"install_command": "sudo pacman -S dotnet-sdk",
|
||||
},
|
||||
),
|
||||
):
|
||||
result = await server.install_ilspy()
|
||||
|
||||
assert "dotnet CLI is not installed" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_install_already_installed(self):
|
||||
"""Test install_ilspy when already installed."""
|
||||
mock_status = {
|
||||
"dotnet_available": True,
|
||||
"dotnet_version": "8.0.100",
|
||||
"ilspycmd_available": True,
|
||||
"ilspycmd_version": "8.2.0",
|
||||
"ilspycmd_path": "/home/user/.dotnet/tools/ilspycmd",
|
||||
}
|
||||
|
||||
with patch.object(server, "_check_dotnet_tools", return_value=mock_status):
|
||||
result = await server.install_ilspy()
|
||||
|
||||
assert "already installed" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_install_fails(self):
|
||||
"""Test install_ilspy when installation fails."""
|
||||
mock_status_before = {
|
||||
"dotnet_available": True,
|
||||
"dotnet_version": "8.0.100",
|
||||
"ilspycmd_available": False,
|
||||
"ilspycmd_version": None,
|
||||
"ilspycmd_path": None,
|
||||
}
|
||||
|
||||
# Mock subprocess to simulate installation failure
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.returncode = 1
|
||||
mock_proc.communicate = AsyncMock(return_value=(b"", b"Installation failed"))
|
||||
|
||||
with (
|
||||
patch.object(server, "_check_dotnet_tools", return_value=mock_status_before),
|
||||
patch("asyncio.create_subprocess_exec", return_value=mock_proc),
|
||||
):
|
||||
result = await server.install_ilspy()
|
||||
|
||||
assert "Installation failed" in result or "failed" in result.lower()
|
||||
@ -3,9 +3,9 @@
|
||||
import pytest
|
||||
|
||||
from mcilspy.models import (
|
||||
DecompileRequest,
|
||||
EntityType,
|
||||
LanguageVersion,
|
||||
DecompileRequest,
|
||||
ListTypesRequest,
|
||||
TypeInfo,
|
||||
)
|
||||
|
||||
288
tests/test_security.py
Normal file
288
tests/test_security.py
Normal file
@ -0,0 +1,288 @@
|
||||
"""Security-focused tests for mcilspy validation functions.
|
||||
|
||||
These tests verify the security hardening in S1-S4:
|
||||
- S1: Path traversal prevention via _validate_assembly_path()
|
||||
- S2: Temp directory race condition fix (structural - uses TemporaryDirectory)
|
||||
- S3: Subprocess output size limits (MAX_OUTPUT_BYTES)
|
||||
- S4: Assembly file size limits (MAX_ASSEMBLY_SIZE_MB)
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestValidateAssemblyPath:
|
||||
"""Tests for S1: _validate_assembly_path() security validation."""
|
||||
|
||||
def test_empty_path_rejected(self):
|
||||
"""Empty paths should be rejected."""
|
||||
from mcilspy.server import AssemblyPathError, _validate_assembly_path
|
||||
|
||||
with pytest.raises(AssemblyPathError, match="cannot be empty"):
|
||||
_validate_assembly_path("")
|
||||
|
||||
with pytest.raises(AssemblyPathError, match="cannot be empty"):
|
||||
_validate_assembly_path(" ")
|
||||
|
||||
def test_nonexistent_file_rejected(self):
|
||||
"""Non-existent files should be rejected."""
|
||||
from mcilspy.server import AssemblyPathError, _validate_assembly_path
|
||||
|
||||
with pytest.raises(AssemblyPathError, match="not found"):
|
||||
_validate_assembly_path("/nonexistent/path/to/assembly.dll")
|
||||
|
||||
def test_directory_rejected(self):
|
||||
"""Directories should be rejected (file required)."""
|
||||
from mcilspy.server import AssemblyPathError, _validate_assembly_path
|
||||
|
||||
with pytest.raises(AssemblyPathError, match="not a file"):
|
||||
_validate_assembly_path("/tmp")
|
||||
|
||||
def test_invalid_extension_rejected(self):
|
||||
"""Files without .dll or .exe extension should be rejected."""
|
||||
from mcilspy.server import AssemblyPathError, _validate_assembly_path
|
||||
|
||||
# Create a temp file with wrong extension
|
||||
with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as f:
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
with pytest.raises(AssemblyPathError, match="Invalid assembly extension"):
|
||||
_validate_assembly_path(temp_path)
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
def test_valid_dll_accepted(self):
|
||||
"""Valid .dll files should be accepted."""
|
||||
from mcilspy.server import _validate_assembly_path
|
||||
|
||||
# Create a temp file with .dll extension
|
||||
with tempfile.NamedTemporaryFile(suffix=".dll", delete=False) as f:
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
result = _validate_assembly_path(temp_path)
|
||||
# Should return absolute path
|
||||
assert os.path.isabs(result)
|
||||
assert result.endswith(".dll")
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
def test_valid_exe_accepted(self):
|
||||
"""Valid .exe files should be accepted."""
|
||||
from mcilspy.server import _validate_assembly_path
|
||||
|
||||
# Create a temp file with .exe extension
|
||||
with tempfile.NamedTemporaryFile(suffix=".exe", delete=False) as f:
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
result = _validate_assembly_path(temp_path)
|
||||
# Should return absolute path
|
||||
assert os.path.isabs(result)
|
||||
assert result.endswith(".exe")
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
def test_case_insensitive_extension(self):
|
||||
"""Extension check should be case-insensitive."""
|
||||
from mcilspy.server import _validate_assembly_path
|
||||
|
||||
# Create temp files with mixed case extensions
|
||||
for ext in [".DLL", ".Dll", ".EXE", ".Exe"]:
|
||||
with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as f:
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
result = _validate_assembly_path(temp_path)
|
||||
assert os.path.isabs(result)
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
def test_path_traversal_resolved(self):
|
||||
"""Path traversal attempts should be resolved to absolute paths."""
|
||||
from mcilspy.server import _validate_assembly_path
|
||||
|
||||
# Create a temp file
|
||||
with tempfile.NamedTemporaryFile(suffix=".dll", delete=False) as f:
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
# Create a path with .. components
|
||||
parent = os.path.dirname(temp_path)
|
||||
filename = os.path.basename(temp_path)
|
||||
traversal_path = os.path.join(parent, "..", os.path.basename(parent), filename)
|
||||
|
||||
result = _validate_assembly_path(traversal_path)
|
||||
# Should resolve to absolute path without ..
|
||||
assert ".." not in result
|
||||
assert os.path.isabs(result)
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
def test_tilde_expansion(self):
|
||||
"""Home directory tilde should be expanded."""
|
||||
from mcilspy.server import AssemblyPathError, _validate_assembly_path
|
||||
|
||||
# This should fail because the file doesn't exist, but the path should be expanded
|
||||
with pytest.raises(AssemblyPathError, match="not found"):
|
||||
_validate_assembly_path("~/nonexistent.dll")
|
||||
|
||||
|
||||
class TestMaxOutputBytes:
|
||||
"""Tests for S3: MAX_OUTPUT_BYTES constant and truncation."""
|
||||
|
||||
def test_constant_defined(self):
|
||||
"""MAX_OUTPUT_BYTES constant should be defined."""
|
||||
from mcilspy.ilspy_wrapper import MAX_OUTPUT_BYTES
|
||||
|
||||
assert MAX_OUTPUT_BYTES == 50_000_000 # 50 MB
|
||||
assert MAX_OUTPUT_BYTES > 0
|
||||
|
||||
def test_constant_reasonable_size(self):
|
||||
"""MAX_OUTPUT_BYTES should be a reasonable size (not too small, not too large)."""
|
||||
from mcilspy.ilspy_wrapper import MAX_OUTPUT_BYTES
|
||||
|
||||
# Should be at least 1 MB
|
||||
assert MAX_OUTPUT_BYTES >= 1_000_000
|
||||
# Should be at most 100 MB to prevent memory issues
|
||||
assert MAX_OUTPUT_BYTES <= 100_000_000
|
||||
|
||||
|
||||
class TestMaxAssemblySize:
|
||||
"""Tests for S4: MAX_ASSEMBLY_SIZE_MB constant and size check."""
|
||||
|
||||
def test_constant_defined(self):
|
||||
"""MAX_ASSEMBLY_SIZE_MB constant should be defined."""
|
||||
from mcilspy.metadata_reader import MAX_ASSEMBLY_SIZE_MB
|
||||
|
||||
assert MAX_ASSEMBLY_SIZE_MB == 500 # 500 MB
|
||||
assert MAX_ASSEMBLY_SIZE_MB > 0
|
||||
|
||||
def test_assembly_size_error_defined(self):
|
||||
"""AssemblySizeError exception should be defined."""
|
||||
from mcilspy.metadata_reader import AssemblySizeError
|
||||
|
||||
assert issubclass(AssemblySizeError, ValueError)
|
||||
|
||||
def test_oversized_file_rejected(self):
|
||||
"""Files exceeding MAX_ASSEMBLY_SIZE_MB should be rejected."""
|
||||
from mcilspy.metadata_reader import (
|
||||
AssemblySizeError,
|
||||
MAX_ASSEMBLY_SIZE_MB,
|
||||
MetadataReader,
|
||||
)
|
||||
|
||||
# Create a temp file
|
||||
with tempfile.NamedTemporaryFile(suffix=".dll", delete=False) as f:
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
# Mock the file size to be over the limit
|
||||
mock_stat_result = os.stat(temp_path)
|
||||
oversized_bytes = (MAX_ASSEMBLY_SIZE_MB + 1) * 1024 * 1024
|
||||
|
||||
with patch.object(Path, "stat") as mock_stat:
|
||||
mock_stat.return_value = type(
|
||||
"StatResult",
|
||||
(),
|
||||
{"st_size": oversized_bytes},
|
||||
)()
|
||||
|
||||
with pytest.raises(AssemblySizeError, match="exceeds maximum"):
|
||||
MetadataReader(temp_path)
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
def test_normal_sized_file_accepted(self):
|
||||
"""Files under MAX_ASSEMBLY_SIZE_MB should be accepted (at init)."""
|
||||
from mcilspy.metadata_reader import MetadataReader
|
||||
|
||||
# Create a small temp file
|
||||
with tempfile.NamedTemporaryFile(suffix=".dll", delete=False) as f:
|
||||
f.write(b"small content")
|
||||
temp_path = f.name
|
||||
|
||||
try:
|
||||
# Should not raise on init (will fail later when trying to parse)
|
||||
reader = MetadataReader(temp_path)
|
||||
assert reader.assembly_path.exists()
|
||||
finally:
|
||||
os.unlink(temp_path)
|
||||
|
||||
|
||||
class TestTemporaryDirectoryUsage:
|
||||
"""Tests for S2: TemporaryDirectory context manager usage.
|
||||
|
||||
These are structural tests that verify the code uses the secure pattern.
|
||||
"""
|
||||
|
||||
def test_decompile_uses_temp_directory_context(self):
|
||||
"""Verify decompile method structure uses TemporaryDirectory."""
|
||||
import inspect
|
||||
|
||||
from mcilspy.ilspy_wrapper import ILSpyWrapper
|
||||
|
||||
# Get source code of decompile method
|
||||
source = inspect.getsource(ILSpyWrapper.decompile)
|
||||
|
||||
# Should use TemporaryDirectory context manager
|
||||
assert "tempfile.TemporaryDirectory()" in source
|
||||
assert "with tempfile.TemporaryDirectory()" in source
|
||||
|
||||
# Should NOT use the old mkdtemp pattern
|
||||
assert "tempfile.mkdtemp()" not in source
|
||||
|
||||
def test_get_assembly_info_uses_temp_directory_context(self):
|
||||
"""Verify get_assembly_info method structure uses TemporaryDirectory."""
|
||||
import inspect
|
||||
|
||||
from mcilspy.ilspy_wrapper import ILSpyWrapper
|
||||
|
||||
# Get source code of get_assembly_info method
|
||||
source = inspect.getsource(ILSpyWrapper.get_assembly_info)
|
||||
|
||||
# Should use TemporaryDirectory context manager
|
||||
assert "with tempfile.TemporaryDirectory()" in source
|
||||
|
||||
# Should NOT use the old mkdtemp pattern
|
||||
assert "tempfile.mkdtemp()" not in source
|
||||
|
||||
|
||||
class TestPathValidationIntegration:
|
||||
"""Integration tests to verify path validation is applied to all tools."""
|
||||
|
||||
def test_all_tools_have_path_validation(self):
|
||||
"""Verify all assembly-accepting tools call _validate_assembly_path."""
|
||||
import inspect
|
||||
|
||||
from mcilspy import server
|
||||
|
||||
# List of tool functions that accept assembly_path
|
||||
tools_with_assembly_path = [
|
||||
"decompile_assembly",
|
||||
"list_types",
|
||||
"generate_diagrammer",
|
||||
"get_assembly_info",
|
||||
"search_types",
|
||||
"search_strings",
|
||||
"search_methods",
|
||||
"search_fields",
|
||||
"search_properties",
|
||||
"list_events",
|
||||
"list_resources",
|
||||
"get_metadata_summary",
|
||||
]
|
||||
|
||||
for tool_name in tools_with_assembly_path:
|
||||
tool_func = getattr(server, tool_name)
|
||||
source = inspect.getsource(tool_func)
|
||||
|
||||
# Each tool should call _validate_assembly_path
|
||||
assert "_validate_assembly_path" in source, (
|
||||
f"Tool '{tool_name}' does not call _validate_assembly_path"
|
||||
)
|
||||
597
tests/test_server_tools.py
Normal file
597
tests/test_server_tools.py
Normal file
@ -0,0 +1,597 @@
|
||||
"""Tests for MCP server tool functions.
|
||||
|
||||
These tests exercise the @mcp.tool() decorated functions in server.py.
|
||||
We mock the ILSpyWrapper to test the tool logic independently of ilspycmd.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from mcilspy import server
|
||||
from mcilspy.models import (
|
||||
AssemblyInfo,
|
||||
DecompileResponse,
|
||||
ListTypesResponse,
|
||||
TypeInfo,
|
||||
)
|
||||
from mcilspy.utils import find_ilspycmd_path
|
||||
|
||||
|
||||
# Fixture to bypass path validation for tests using mock paths
|
||||
@pytest.fixture
|
||||
def bypass_path_validation():
|
||||
"""Bypass _validate_assembly_path for tests using mock wrapper."""
|
||||
def passthrough(path):
|
||||
return path
|
||||
with patch.object(server, "_validate_assembly_path", side_effect=passthrough):
|
||||
yield
|
||||
|
||||
|
||||
class TestCheckIlspyInstallation:
|
||||
"""Tests for check_ilspy_installation tool."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_both_installed(self):
|
||||
"""Test when both dotnet and ilspycmd are installed."""
|
||||
mock_status = {
|
||||
"dotnet_available": True,
|
||||
"dotnet_version": "8.0.100",
|
||||
"ilspycmd_available": True,
|
||||
"ilspycmd_version": "8.2.0",
|
||||
"ilspycmd_path": "/home/user/.dotnet/tools/ilspycmd",
|
||||
}
|
||||
|
||||
with patch.object(server, "_check_dotnet_tools", return_value=mock_status):
|
||||
result = await server.check_ilspy_installation()
|
||||
|
||||
assert "dotnet CLI" in result
|
||||
assert "8.0.100" in result
|
||||
assert "ilspycmd" in result
|
||||
assert "8.2.0" in result
|
||||
assert "ready to use" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dotnet_not_installed(self):
|
||||
"""Test when dotnet is not installed."""
|
||||
mock_status = {
|
||||
"dotnet_available": False,
|
||||
"dotnet_version": None,
|
||||
"ilspycmd_available": False,
|
||||
"ilspycmd_version": None,
|
||||
"ilspycmd_path": None,
|
||||
}
|
||||
|
||||
with patch.object(server, "_check_dotnet_tools", return_value=mock_status):
|
||||
result = await server.check_ilspy_installation()
|
||||
|
||||
assert "Not found" in result
|
||||
assert "dotnet.microsoft.com" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ilspycmd_not_installed(self):
|
||||
"""Test when dotnet is installed but ilspycmd is not."""
|
||||
mock_status = {
|
||||
"dotnet_available": True,
|
||||
"dotnet_version": "8.0.100",
|
||||
"ilspycmd_available": False,
|
||||
"ilspycmd_version": None,
|
||||
"ilspycmd_path": None,
|
||||
}
|
||||
|
||||
with patch.object(server, "_check_dotnet_tools", return_value=mock_status):
|
||||
result = await server.check_ilspy_installation()
|
||||
|
||||
assert "ilspycmd" in result
|
||||
assert "Not installed" in result
|
||||
assert "install_ilspy" in result.lower() or "dotnet tool install" in result
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("bypass_path_validation")
|
||||
class TestDecompileAssembly:
|
||||
"""Tests for decompile_assembly tool."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_decompile(self):
|
||||
"""Test successful decompilation returns formatted output."""
|
||||
mock_response = DecompileResponse(
|
||||
success=True,
|
||||
assembly_name="TestAssembly",
|
||||
type_name="MyClass",
|
||||
source_code="public class MyClass { }",
|
||||
)
|
||||
|
||||
mock_wrapper = MagicMock()
|
||||
mock_wrapper.decompile = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(server, "get_wrapper", return_value=mock_wrapper):
|
||||
result = await server.decompile_assembly("/path/to/test.dll")
|
||||
|
||||
assert "Decompilation result" in result
|
||||
assert "TestAssembly" in result
|
||||
assert "public class MyClass" in result
|
||||
assert "```csharp" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompile_with_output_dir(self):
|
||||
"""Test decompilation to output directory."""
|
||||
mock_response = DecompileResponse(
|
||||
success=True,
|
||||
assembly_name="TestAssembly",
|
||||
output_path="/tmp/output",
|
||||
)
|
||||
|
||||
mock_wrapper = MagicMock()
|
||||
mock_wrapper.decompile = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(server, "get_wrapper", return_value=mock_wrapper):
|
||||
result = await server.decompile_assembly(
|
||||
"/path/to/test.dll", output_dir="/tmp/output"
|
||||
)
|
||||
|
||||
assert "successful" in result.lower()
|
||||
assert "/tmp/output" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompile_failure(self):
|
||||
"""Test failed decompilation returns error message."""
|
||||
mock_response = DecompileResponse(
|
||||
success=False,
|
||||
assembly_name="TestAssembly",
|
||||
error_message="Assembly not found",
|
||||
)
|
||||
|
||||
mock_wrapper = MagicMock()
|
||||
mock_wrapper.decompile = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(server, "get_wrapper", return_value=mock_wrapper):
|
||||
result = await server.decompile_assembly("/path/to/nonexistent.dll")
|
||||
|
||||
assert "failed" in result.lower()
|
||||
assert "Assembly not found" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompile_with_type_name(self):
|
||||
"""Test decompiling a specific type."""
|
||||
mock_response = DecompileResponse(
|
||||
success=True,
|
||||
assembly_name="TestAssembly",
|
||||
type_name="MyNamespace.MyClass",
|
||||
source_code="namespace MyNamespace { public class MyClass { } }",
|
||||
)
|
||||
|
||||
mock_wrapper = MagicMock()
|
||||
mock_wrapper.decompile = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(server, "get_wrapper", return_value=mock_wrapper):
|
||||
result = await server.decompile_assembly(
|
||||
"/path/to/test.dll", type_name="MyNamespace.MyClass"
|
||||
)
|
||||
|
||||
assert "MyNamespace.MyClass" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompile_exception_handling(self):
|
||||
"""Test that exceptions are handled gracefully."""
|
||||
mock_wrapper = MagicMock()
|
||||
mock_wrapper.decompile = AsyncMock(side_effect=RuntimeError("Test error"))
|
||||
|
||||
with patch.object(server, "get_wrapper", return_value=mock_wrapper):
|
||||
result = await server.decompile_assembly("/path/to/test.dll")
|
||||
|
||||
assert "Error" in result
|
||||
assert "Test error" in result
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("bypass_path_validation")
|
||||
class TestListTypes:
|
||||
"""Tests for list_types tool."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_types_success(self):
|
||||
"""Test successful type listing."""
|
||||
mock_types = [
|
||||
TypeInfo(name="ClassA", full_name="NS.ClassA", kind="Class", namespace="NS"),
|
||||
TypeInfo(name="ClassB", full_name="NS.ClassB", kind="Class", namespace="NS"),
|
||||
TypeInfo(name="IService", full_name="NS.IService", kind="Interface", namespace="NS"),
|
||||
]
|
||||
mock_response = ListTypesResponse(
|
||||
success=True,
|
||||
types=mock_types,
|
||||
total_count=3,
|
||||
)
|
||||
|
||||
mock_wrapper = MagicMock()
|
||||
mock_wrapper.list_types = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(server, "get_wrapper", return_value=mock_wrapper):
|
||||
result = await server.list_types("/path/to/test.dll")
|
||||
|
||||
assert "Types in" in result
|
||||
# New pagination format: "Showing X of Y types"
|
||||
assert "Showing 3 of 3 types" in result or "Found 3 types" in result
|
||||
assert "ClassA" in result
|
||||
assert "ClassB" in result
|
||||
assert "IService" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_types_grouped_by_namespace(self):
|
||||
"""Test that types are grouped by namespace."""
|
||||
mock_types = [
|
||||
TypeInfo(name="ClassA", full_name="NS1.ClassA", kind="Class", namespace="NS1"),
|
||||
TypeInfo(name="ClassB", full_name="NS2.ClassB", kind="Class", namespace="NS2"),
|
||||
]
|
||||
mock_response = ListTypesResponse(
|
||||
success=True,
|
||||
types=mock_types,
|
||||
total_count=2,
|
||||
)
|
||||
|
||||
mock_wrapper = MagicMock()
|
||||
mock_wrapper.list_types = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(server, "get_wrapper", return_value=mock_wrapper):
|
||||
result = await server.list_types("/path/to/test.dll")
|
||||
|
||||
# Should have namespace headers
|
||||
assert "## NS1" in result
|
||||
assert "## NS2" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_types_with_entity_types(self):
|
||||
"""Test listing specific entity types."""
|
||||
mock_response = ListTypesResponse(
|
||||
success=True,
|
||||
types=[TypeInfo(name="IService", full_name="NS.IService", kind="Interface", namespace="NS")],
|
||||
total_count=1,
|
||||
)
|
||||
|
||||
mock_wrapper = MagicMock()
|
||||
mock_wrapper.list_types = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(server, "get_wrapper", return_value=mock_wrapper):
|
||||
result = await server.list_types(
|
||||
"/path/to/test.dll", entity_types=["interface"]
|
||||
)
|
||||
|
||||
assert "IService" in result
|
||||
mock_wrapper.list_types.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_types_no_types_found(self):
|
||||
"""Test when no types are found."""
|
||||
mock_response = ListTypesResponse(
|
||||
success=True,
|
||||
types=[],
|
||||
total_count=0,
|
||||
error_message="No types found in assembly",
|
||||
)
|
||||
|
||||
mock_wrapper = MagicMock()
|
||||
mock_wrapper.list_types = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(server, "get_wrapper", return_value=mock_wrapper):
|
||||
result = await server.list_types("/path/to/test.dll")
|
||||
|
||||
assert "No types found" in result
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("bypass_path_validation")
|
||||
class TestSearchTypes:
|
||||
"""Tests for search_types tool."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_types_finds_matches(self):
|
||||
"""Test searching types by pattern."""
|
||||
mock_types = [
|
||||
TypeInfo(name="UserService", full_name="NS.UserService", kind="Class", namespace="NS"),
|
||||
TypeInfo(name="OrderService", full_name="NS.OrderService", kind="Class", namespace="NS"),
|
||||
TypeInfo(name="Helper", full_name="NS.Helper", kind="Class", namespace="NS"),
|
||||
]
|
||||
mock_response = ListTypesResponse(
|
||||
success=True,
|
||||
types=mock_types,
|
||||
total_count=3,
|
||||
)
|
||||
|
||||
mock_wrapper = MagicMock()
|
||||
mock_wrapper.list_types = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(server, "get_wrapper", return_value=mock_wrapper):
|
||||
result = await server.search_types("/path/to/test.dll", pattern="Service")
|
||||
|
||||
assert "Search Results" in result
|
||||
assert "UserService" in result
|
||||
assert "OrderService" in result
|
||||
assert "Helper" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_types_case_insensitive(self):
|
||||
"""Test case-insensitive search (default)."""
|
||||
mock_types = [
|
||||
TypeInfo(name="SERVICE", full_name="NS.SERVICE", kind="Class", namespace="NS"),
|
||||
]
|
||||
mock_response = ListTypesResponse(success=True, types=mock_types, total_count=1)
|
||||
|
||||
mock_wrapper = MagicMock()
|
||||
mock_wrapper.list_types = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(server, "get_wrapper", return_value=mock_wrapper):
|
||||
result = await server.search_types(
|
||||
"/path/to/test.dll", pattern="service", case_sensitive=False
|
||||
)
|
||||
|
||||
assert "SERVICE" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_types_with_namespace_filter(self):
|
||||
"""Test searching with namespace filter."""
|
||||
mock_types = [
|
||||
TypeInfo(name="ClassA", full_name="App.Services.ClassA", kind="Class", namespace="App.Services"),
|
||||
TypeInfo(name="ClassB", full_name="App.Models.ClassB", kind="Class", namespace="App.Models"),
|
||||
]
|
||||
mock_response = ListTypesResponse(success=True, types=mock_types, total_count=2)
|
||||
|
||||
mock_wrapper = MagicMock()
|
||||
mock_wrapper.list_types = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(server, "get_wrapper", return_value=mock_wrapper):
|
||||
result = await server.search_types(
|
||||
"/path/to/test.dll", pattern="Class", namespace_filter="Services"
|
||||
)
|
||||
|
||||
assert "ClassA" in result
|
||||
assert "ClassB" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_types_no_matches(self):
|
||||
"""Test when no types match the pattern."""
|
||||
mock_types = [
|
||||
TypeInfo(name="Helper", full_name="NS.Helper", kind="Class", namespace="NS"),
|
||||
]
|
||||
mock_response = ListTypesResponse(success=True, types=mock_types, total_count=1)
|
||||
|
||||
mock_wrapper = MagicMock()
|
||||
mock_wrapper.list_types = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(server, "get_wrapper", return_value=mock_wrapper):
|
||||
result = await server.search_types("/path/to/test.dll", pattern="Service")
|
||||
|
||||
assert "No types found" in result
|
||||
|
||||
|
||||
class TestSearchMethods:
|
||||
"""Tests for search_methods tool (using metadata reader)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_methods_uses_metadata_reader(self, test_assembly_path):
|
||||
"""Test that search_methods uses MetadataReader directly."""
|
||||
result = await server.search_methods(test_assembly_path, pattern="Do")
|
||||
|
||||
assert "DoSomething" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_methods_filters_by_pattern(self, test_assembly_path):
|
||||
"""Test method name pattern filtering."""
|
||||
result = await server.search_methods(test_assembly_path, pattern="Get")
|
||||
|
||||
assert "GetGreeting" in result
|
||||
assert "DoSomething" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_methods_public_only(self, test_assembly_path):
|
||||
"""Test filtering for public methods."""
|
||||
result = await server.search_methods(
|
||||
test_assembly_path, pattern="Method", public_only=True
|
||||
)
|
||||
|
||||
# Should find public methods
|
||||
assert "Method" in result or "No methods found" in result
|
||||
|
||||
|
||||
class TestSearchFields:
|
||||
"""Tests for search_fields tool."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_fields_finds_constants(self, test_assembly_path):
|
||||
"""Test finding constant fields."""
|
||||
result = await server.search_fields(test_assembly_path, pattern="API")
|
||||
|
||||
assert "API_KEY" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_fields_constants_only(self, test_assembly_path):
|
||||
"""Test filtering for constants only."""
|
||||
result = await server.search_fields(
|
||||
test_assembly_path, pattern="", constants_only=True
|
||||
)
|
||||
|
||||
assert "const" in result
|
||||
|
||||
|
||||
class TestSearchProperties:
|
||||
"""Tests for search_properties tool."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_properties(self, test_assembly_path):
|
||||
"""Test searching for properties."""
|
||||
result = await server.search_properties(test_assembly_path, pattern="Name")
|
||||
|
||||
assert "Name" in result
|
||||
|
||||
|
||||
class TestListEvents:
|
||||
"""Tests for list_events tool."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_events(self, test_assembly_path):
|
||||
"""Test listing events from assembly."""
|
||||
result = await server.list_events(test_assembly_path)
|
||||
|
||||
assert "OnChange" in result or "No events" in result
|
||||
|
||||
|
||||
class TestListResources:
|
||||
"""Tests for list_resources tool."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_resources_empty(self, test_assembly_path):
|
||||
"""Test listing resources (test assembly has none)."""
|
||||
result = await server.list_resources(test_assembly_path)
|
||||
|
||||
assert "No embedded resources" in result or "Embedded Resources" in result
|
||||
|
||||
|
||||
class TestGetMetadataSummary:
|
||||
"""Tests for get_metadata_summary tool."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_metadata_summary(self, test_assembly_path):
|
||||
"""Test getting metadata summary."""
|
||||
result = await server.get_metadata_summary(test_assembly_path)
|
||||
|
||||
assert "Assembly Metadata Summary" in result
|
||||
assert "Name" in result
|
||||
assert "Version" in result
|
||||
assert "Statistics" in result
|
||||
assert "Types" in result
|
||||
assert "Methods" in result
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("bypass_path_validation")
|
||||
class TestGetAssemblyInfo:
|
||||
"""Tests for get_assembly_info tool."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_assembly_info_success(self):
|
||||
"""Test getting assembly info successfully."""
|
||||
mock_info = AssemblyInfo(
|
||||
name="TestAssembly",
|
||||
full_name="TestAssembly, Version=1.0.0.0",
|
||||
location="/path/to/test.dll",
|
||||
version="1.0.0.0",
|
||||
target_framework=".NETStandard,Version=v2.0",
|
||||
is_signed=False,
|
||||
has_debug_info=False,
|
||||
)
|
||||
|
||||
mock_wrapper = MagicMock()
|
||||
mock_wrapper.get_assembly_info = AsyncMock(return_value=mock_info)
|
||||
|
||||
with patch.object(server, "get_wrapper", return_value=mock_wrapper):
|
||||
result = await server.get_assembly_info("/path/to/test.dll")
|
||||
|
||||
assert "Assembly Information" in result
|
||||
assert "TestAssembly" in result
|
||||
assert "1.0.0.0" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_assembly_info_exception(self):
|
||||
"""Test handling of exceptions in get_assembly_info."""
|
||||
mock_wrapper = MagicMock()
|
||||
mock_wrapper.get_assembly_info = AsyncMock(
|
||||
side_effect=FileNotFoundError("File not found")
|
||||
)
|
||||
|
||||
with patch.object(server, "get_wrapper", return_value=mock_wrapper):
|
||||
result = await server.get_assembly_info("/nonexistent/file.dll")
|
||||
|
||||
assert "Error" in result
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("bypass_path_validation")
|
||||
class TestGenerateDiagrammer:
|
||||
"""Tests for generate_diagrammer tool."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_diagrammer_success(self):
|
||||
"""Test successful diagram generation."""
|
||||
mock_response = {
|
||||
"success": True,
|
||||
"output_directory": "/tmp/diagrammer",
|
||||
}
|
||||
|
||||
mock_wrapper = MagicMock()
|
||||
mock_wrapper.generate_diagrammer = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(server, "get_wrapper", return_value=mock_wrapper):
|
||||
result = await server.generate_diagrammer("/path/to/test.dll")
|
||||
|
||||
assert "successfully" in result.lower()
|
||||
assert "/tmp/diagrammer" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_diagrammer_failure(self):
|
||||
"""Test failed diagram generation."""
|
||||
mock_response = {
|
||||
"success": False,
|
||||
"error_message": "Failed to generate diagram",
|
||||
}
|
||||
|
||||
mock_wrapper = MagicMock()
|
||||
mock_wrapper.generate_diagrammer = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(server, "get_wrapper", return_value=mock_wrapper):
|
||||
result = await server.generate_diagrammer("/path/to/test.dll")
|
||||
|
||||
assert "Failed" in result
|
||||
|
||||
|
||||
class TestHelperFunctions:
|
||||
"""Tests for helper functions in server.py."""
|
||||
|
||||
def test_format_error_with_context(self):
|
||||
"""Test _format_error with context."""
|
||||
error = ValueError("test error")
|
||||
result = server._format_error(error, "testing")
|
||||
|
||||
assert "Error" in result
|
||||
assert "testing" in result
|
||||
assert "test error" in result
|
||||
|
||||
def test_format_error_without_context(self):
|
||||
"""Test _format_error without context."""
|
||||
error = RuntimeError("something went wrong")
|
||||
result = server._format_error(error)
|
||||
|
||||
assert "Error" in result
|
||||
assert "something went wrong" in result
|
||||
|
||||
def test_find_ilspycmd_path_not_installed(self):
|
||||
"""Test find_ilspycmd_path when not installed."""
|
||||
with (
|
||||
patch("mcilspy.utils.shutil.which", return_value=None),
|
||||
patch("mcilspy.utils.os.path.isfile", return_value=False),
|
||||
):
|
||||
result = find_ilspycmd_path()
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_find_ilspycmd_path_in_path(self):
|
||||
"""Test find_ilspycmd_path when in PATH."""
|
||||
with patch("mcilspy.utils.shutil.which", return_value="/usr/local/bin/ilspycmd"):
|
||||
result = find_ilspycmd_path()
|
||||
|
||||
assert result == "/usr/local/bin/ilspycmd"
|
||||
|
||||
def test_detect_platform_linux(self):
|
||||
"""Test platform detection on Linux."""
|
||||
with (
|
||||
patch("platform.system", return_value="Linux"),
|
||||
patch("builtins.open", MagicMock()),
|
||||
patch("shutil.which", return_value="/usr/bin/pacman"),
|
||||
):
|
||||
result = server._detect_platform()
|
||||
|
||||
assert result["system"] == "linux"
|
||||
assert result["package_manager"] is not None
|
||||
|
||||
def test_detect_platform_windows(self):
|
||||
"""Test platform detection on Windows."""
|
||||
with (
|
||||
patch("platform.system", return_value="Windows"),
|
||||
patch("shutil.which", return_value=None),
|
||||
):
|
||||
result = server._detect_platform()
|
||||
|
||||
assert result["system"] == "windows"
|
||||
261
tests/test_timeout.py
Normal file
261
tests/test_timeout.py
Normal file
@ -0,0 +1,261 @@
|
||||
"""Tests for timeout behavior.
|
||||
|
||||
Verifies that the 5-minute timeout in ILSpyWrapper works correctly
|
||||
and that hanging processes are properly killed.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from mcilspy.ilspy_wrapper import ILSpyWrapper
|
||||
from mcilspy.models import DecompileRequest, LanguageVersion, ListTypesRequest
|
||||
|
||||
|
||||
class TestTimeoutBehavior:
|
||||
"""Tests for the 5-minute timeout in _run_command."""
|
||||
|
||||
@pytest.fixture
|
||||
def wrapper(self):
|
||||
"""Create a wrapper with mocked ilspycmd path."""
|
||||
with patch("mcilspy.utils.find_ilspycmd_path", return_value="/mock/ilspycmd"):
|
||||
return ILSpyWrapper()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_returns_error_message(self, wrapper):
|
||||
"""Test that timeout produces appropriate error message."""
|
||||
# Create a mock process that never completes
|
||||
mock_process = MagicMock()
|
||||
mock_process.communicate = AsyncMock(side_effect=asyncio.TimeoutError())
|
||||
mock_process.kill = MagicMock()
|
||||
mock_process.wait = AsyncMock()
|
||||
|
||||
with patch("asyncio.create_subprocess_exec", return_value=mock_process):
|
||||
return_code, stdout, stderr = await wrapper._run_command(["test", "args"])
|
||||
|
||||
assert return_code == -1
|
||||
assert "timed out" in stderr.lower()
|
||||
assert "5 minutes" in stderr
|
||||
mock_process.kill.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_killed_on_timeout(self, wrapper):
|
||||
"""Test that the process is killed when timeout occurs."""
|
||||
mock_process = MagicMock()
|
||||
mock_process.communicate = AsyncMock(side_effect=asyncio.TimeoutError())
|
||||
mock_process.kill = MagicMock()
|
||||
mock_process.wait = AsyncMock()
|
||||
|
||||
with patch("asyncio.create_subprocess_exec", return_value=mock_process):
|
||||
await wrapper._run_command(["test"])
|
||||
|
||||
# Verify kill was called
|
||||
mock_process.kill.assert_called_once()
|
||||
# Verify we waited for the process to clean up
|
||||
mock_process.wait.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_in_decompile(self, wrapper, test_assembly_path):
|
||||
"""Test timeout behavior during decompile operation."""
|
||||
mock_process = MagicMock()
|
||||
mock_process.communicate = AsyncMock(side_effect=asyncio.TimeoutError())
|
||||
mock_process.kill = MagicMock()
|
||||
mock_process.wait = AsyncMock()
|
||||
|
||||
with patch("asyncio.create_subprocess_exec", return_value=mock_process):
|
||||
request = DecompileRequest(
|
||||
assembly_path=test_assembly_path,
|
||||
language_version=LanguageVersion.LATEST,
|
||||
)
|
||||
response = await wrapper.decompile(request)
|
||||
|
||||
assert response.success is False
|
||||
assert "timed out" in response.error_message.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_in_list_types(self, wrapper, test_assembly_path):
|
||||
"""Test timeout behavior during list_types operation."""
|
||||
mock_process = MagicMock()
|
||||
mock_process.communicate = AsyncMock(side_effect=asyncio.TimeoutError())
|
||||
mock_process.kill = MagicMock()
|
||||
mock_process.wait = AsyncMock()
|
||||
|
||||
with patch("asyncio.create_subprocess_exec", return_value=mock_process):
|
||||
request = ListTypesRequest(assembly_path=test_assembly_path)
|
||||
response = await wrapper.list_types(request)
|
||||
|
||||
assert response.success is False
|
||||
assert "timed out" in response.error_message.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_value_is_5_minutes(self, wrapper):
|
||||
"""Verify the timeout value is 300 seconds (5 minutes)."""
|
||||
mock_process = MagicMock()
|
||||
mock_process.communicate = AsyncMock(return_value=(b"output", b""))
|
||||
mock_process.returncode = 0
|
||||
|
||||
with (
|
||||
patch("asyncio.create_subprocess_exec", return_value=mock_process),
|
||||
patch("asyncio.wait_for") as mock_wait_for,
|
||||
):
|
||||
# Set up the mock to return the communicate result
|
||||
mock_wait_for.return_value = (b"output", b"")
|
||||
|
||||
await wrapper._run_command(["test"])
|
||||
|
||||
# Verify wait_for was called with 300 second timeout
|
||||
mock_wait_for.assert_called_once()
|
||||
args, kwargs = mock_wait_for.call_args
|
||||
assert kwargs.get("timeout") == 300.0 or args[1] == 300.0
|
||||
|
||||
|
||||
class TestNormalOperationWithTimeout:
|
||||
"""Tests that normal operations complete successfully within timeout."""
|
||||
|
||||
@pytest.fixture
|
||||
def wrapper(self):
|
||||
"""Create a wrapper with mocked ilspycmd path."""
|
||||
with patch("mcilspy.utils.find_ilspycmd_path", return_value="/mock/ilspycmd"):
|
||||
return ILSpyWrapper()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fast_operation_completes(self, wrapper):
|
||||
"""Test that fast operations complete normally."""
|
||||
mock_process = MagicMock()
|
||||
mock_process.communicate = AsyncMock(return_value=(b"success output", b""))
|
||||
mock_process.returncode = 0
|
||||
|
||||
with patch("asyncio.create_subprocess_exec", return_value=mock_process):
|
||||
return_code, stdout, stderr = await wrapper._run_command(["test"])
|
||||
|
||||
assert return_code == 0
|
||||
assert stdout == "success output"
|
||||
assert stderr == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_operation_with_stderr_completes(self, wrapper):
|
||||
"""Test that operations with stderr output complete normally."""
|
||||
mock_process = MagicMock()
|
||||
mock_process.communicate = AsyncMock(
|
||||
return_value=(b"output", b"warning message")
|
||||
)
|
||||
mock_process.returncode = 0
|
||||
|
||||
with patch("asyncio.create_subprocess_exec", return_value=mock_process):
|
||||
return_code, stdout, stderr = await wrapper._run_command(["test"])
|
||||
|
||||
assert return_code == 0
|
||||
assert stdout == "output"
|
||||
assert stderr == "warning message"
|
||||
|
||||
|
||||
class TestTimeoutWithAsyncioWaitFor:
|
||||
"""Tests verifying asyncio.wait_for is used correctly."""
|
||||
|
||||
@pytest.fixture
|
||||
def wrapper(self):
|
||||
"""Create a wrapper with mocked ilspycmd path."""
|
||||
with patch("mcilspy.utils.find_ilspycmd_path", return_value="/mock/ilspycmd"):
|
||||
return ILSpyWrapper()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wait_for_is_used(self, wrapper):
|
||||
"""Verify that asyncio.wait_for is used for timeout."""
|
||||
# Read the source code and verify wait_for is used
|
||||
import inspect
|
||||
source = inspect.getsource(wrapper._run_command)
|
||||
assert "asyncio.wait_for" in source or "wait_for" in source
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_value_in_source(self, wrapper):
|
||||
"""Verify timeout is configured via constants."""
|
||||
import inspect
|
||||
source = inspect.getsource(wrapper._run_command)
|
||||
# Should use DECOMPILE_TIMEOUT_SECONDS constant or have timeout reference
|
||||
assert "DECOMPILE_TIMEOUT_SECONDS" in source or "timeout" in source.lower()
|
||||
|
||||
|
||||
class TestTimeoutCleanup:
|
||||
"""Tests for proper cleanup after timeout."""
|
||||
|
||||
@pytest.fixture
|
||||
def wrapper(self):
|
||||
"""Create a wrapper with mocked ilspycmd path."""
|
||||
with patch("mcilspy.utils.find_ilspycmd_path", return_value="/mock/ilspycmd"):
|
||||
return ILSpyWrapper()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_zombie_process_after_timeout(self, wrapper):
|
||||
"""Verify process is properly cleaned up after timeout."""
|
||||
mock_process = MagicMock()
|
||||
mock_process.communicate = AsyncMock(side_effect=asyncio.TimeoutError())
|
||||
mock_process.kill = MagicMock()
|
||||
mock_process.wait = AsyncMock()
|
||||
mock_process.returncode = None
|
||||
|
||||
with patch("asyncio.create_subprocess_exec", return_value=mock_process):
|
||||
await wrapper._run_command(["test"])
|
||||
|
||||
# kill() followed by wait() ensures no zombie
|
||||
mock_process.kill.assert_called_once()
|
||||
mock_process.wait.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_temp_files_cleaned_after_timeout(self, wrapper, test_assembly_path):
|
||||
"""Verify temporary files are cleaned up after timeout."""
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
mock_process = MagicMock()
|
||||
mock_process.communicate = AsyncMock(side_effect=asyncio.TimeoutError())
|
||||
mock_process.kill = MagicMock()
|
||||
mock_process.wait = AsyncMock()
|
||||
|
||||
initial_temp_count = len(os.listdir(tempfile.gettempdir()))
|
||||
|
||||
with patch("asyncio.create_subprocess_exec", return_value=mock_process):
|
||||
request = DecompileRequest(
|
||||
assembly_path=test_assembly_path,
|
||||
)
|
||||
await wrapper.decompile(request)
|
||||
|
||||
# Temp directory should be cleaned up
|
||||
final_temp_count = len(os.listdir(tempfile.gettempdir()))
|
||||
# Should not have more temp files (may have same or fewer)
|
||||
assert final_temp_count <= initial_temp_count + 1 # Allow small margin
|
||||
|
||||
|
||||
class TestExceptionHandling:
|
||||
"""Tests for exception handling in _run_command."""
|
||||
|
||||
@pytest.fixture
|
||||
def wrapper(self):
|
||||
"""Create a wrapper with mocked ilspycmd path."""
|
||||
with patch("mcilspy.utils.find_ilspycmd_path", return_value="/mock/ilspycmd"):
|
||||
return ILSpyWrapper()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_general_exception_handled(self, wrapper):
|
||||
"""Test that general exceptions are caught and returned."""
|
||||
with patch(
|
||||
"asyncio.create_subprocess_exec",
|
||||
side_effect=OSError("Cannot execute"),
|
||||
):
|
||||
return_code, stdout, stderr = await wrapper._run_command(["test"])
|
||||
|
||||
assert return_code == -1
|
||||
assert stdout == ""
|
||||
assert "Cannot execute" in stderr
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_permission_error_handled(self, wrapper):
|
||||
"""Test that permission errors are handled gracefully."""
|
||||
with patch(
|
||||
"asyncio.create_subprocess_exec",
|
||||
side_effect=PermissionError("Access denied"),
|
||||
):
|
||||
return_code, stdout, stderr = await wrapper._run_command(["test"])
|
||||
|
||||
assert return_code == -1
|
||||
assert "Access denied" in stderr
|
||||
Loading…
x
Reference in New Issue
Block a user