mcilspy/src/mcilspy/ilspy_wrapper.py
Ryan Malloy 20d0cd2e3a perf: major performance improvements and code quality fixes
- P1: search_strings now uses dnfile's #US heap directly instead of
  decompiling entire assembly, providing 10-100x speedup
- P2: add pagination (max_results/offset) to all list/search tools
- P5: add proper logging for platform detection failures
- P6: replace generic exception catches with specific exceptions
- P7: fix MetadataReader.__exit__ return type
- P8: add PE signature (MZ header) validation before invoking ilspycmd

All 35 tests pass, ruff check clean.
2026-02-08 11:40:25 -07:00

623 lines
23 KiB
Python

"""Wrapper for ICSharpCode.ILSpyCmd command line tool."""
import asyncio
import logging
import os
import re
import shutil
import tempfile
from pathlib import Path
from typing import Any
from .constants import (
DECOMPILE_TIMEOUT_SECONDS,
MAX_UNPARSED_LOG_LINES,
UNPARSED_LINE_PREVIEW_LENGTH,
)
from .models import (
AssemblyInfo,
AssemblyInfoRequest,
DecompileRequest,
DecompileResponse,
GenerateDiagrammerRequest,
ListTypesRequest,
ListTypesResponse,
TypeInfo,
)
from .utils import find_ilspycmd_path
logger = logging.getLogger(__name__)
# Per-stream cap, in bytes, on the subprocess stdout/stderr that _run_command
# decodes and returns. Anything beyond the cap is dropped and a truncation
# marker is appended to the returned text.
# NOTE(review): the cap is applied after communicate() has already collected the
# full output, so it bounds what is decoded/returned rather than peak memory use.
# The marker text in _run_command hard-codes "50MB"; keep it in sync if this
# value is changed.
MAX_OUTPUT_BYTES = 50_000_000  # 50 MB (decimal megabytes)
# PE file signature constants
_MZ_SIGNATURE = b"MZ"  # DOS header magic number: first two bytes of a PE file
def _validate_pe_signature(file_path: str) -> tuple[bool, str]:
    """Cheaply check that a file starts with the DOS "MZ" magic bytes.

    Only the first two bytes are inspected, so a non-PE input can be rejected
    before ilspycmd is launched. Passing this check does not prove the file
    is a well-formed PE image or a .NET assembly.

    Args:
        file_path: Path to the file to validate

    Returns:
        Tuple of (is_valid, error_message). error_message is empty if valid.
    """
    # Keep the try body to the I/O itself; the comparisons below cannot raise.
    try:
        with open(file_path, "rb") as pe_file:
            magic = pe_file.read(2)
    except OSError as exc:
        return False, f"Cannot read file: {exc}"

    if len(magic) < 2:
        return False, "File is too small to be a valid PE file"
    if magic == _MZ_SIGNATURE:
        return True, ""
    return False, f"Not a valid PE file (missing MZ signature, got {magic!r})"
class ILSpyWrapper:
    """Wrapper class for ILSpy command line tool.

    This class encapsulates all interactions with the ilspycmd CLI tool.
    While the wrapper is stateless in terms of decompilation operations
    (each call is independent), it exists as a class to:

    1. Cache the ilspycmd path lookup - Finding the executable involves
       checking PATH and several common installation locations, which
       is relatively expensive. Caching this on instantiation avoids
       repeated filesystem operations.
    2. Provide a single point of configuration - If ilspycmd is not found,
       we fail fast at wrapper creation rather than on each tool call.
    3. Enable future extensions - The class structure allows adding
       connection pooling, result caching, or other optimizations without
       changing the API.

    The wrapper should be created once and reused across tool calls.
    See get_wrapper() in server.py for the recommended usage pattern.
    """

    # Compiled regex for parsing ilspycmd list output
    # Format: "TypeKind: FullTypeName" (e.g., "Class: MyNamespace.MyClass")
    _TYPE_LINE_PATTERN = re.compile(r"^(\w+):\s*(.+)$")

    # Assembly-level attributes searched for in decompiled C# text.
    _ASSEMBLY_VERSION_PATTERN = re.compile(
        r'\[assembly:\s*AssemblyVersion\s*\(\s*"([^"]+)"\s*\)\s*\]'
    )
    _FILE_VERSION_PATTERN = re.compile(
        r'\[assembly:\s*AssemblyFileVersion\s*\(\s*"([^"]+)"\s*\)\s*\]'
    )
    _TARGET_FRAMEWORK_PATTERN = re.compile(
        r'\[assembly:\s*TargetFramework\s*\(\s*"([^"]+)"'
    )
    _ASSEMBLY_TITLE_PATTERN = re.compile(
        r'\[assembly:\s*AssemblyTitle\s*\(\s*"([^"]+)"\s*\)\s*\]'
    )

    def __init__(self, ilspycmd_path: str | None = None) -> None:
        """Initialize the wrapper.

        Args:
            ilspycmd_path: Path to ilspycmd executable. If None, will try to find it in PATH.

        Raises:
            RuntimeError: If no path was given and none could be located.
        """
        self.ilspycmd_path = ilspycmd_path or find_ilspycmd_path()
        if not self.ilspycmd_path:
            raise RuntimeError(
                "ILSpyCmd not found. Please install it with: dotnet tool install --global ilspycmd"
            )

    # ------------------------------------------------------------------
    # Small pure helpers (no subprocess / project types involved)
    # ------------------------------------------------------------------

    @staticmethod
    def _decode_capped(raw: bytes | None, limit: int) -> tuple[str, int | None]:
        """Decode subprocess output as UTF-8, keeping at most ``limit`` bytes.

        Args:
            raw: Bytes captured from the subprocess (may be None or empty)
            limit: Maximum number of bytes to keep before decoding

        Returns:
            Tuple of (text, original_size). original_size is the length of
            ``raw`` before cutting when truncation happened, otherwise None.
        """
        if not raw:
            return "", None
        original_size = len(raw)
        if original_size <= limit:
            return raw.decode("utf-8", errors="replace"), None
        # errors="replace" also covers a multi-byte character cut at the limit.
        return raw[:limit].decode("utf-8", errors="replace"), original_size

    @staticmethod
    def _candidate_output_files(
        output_dir: str, stems: list[str], include_il: bool = False
    ) -> list[str]:
        """Build the ordered list of files ilspycmd may have written.

        "<stem>.cs" is always probed first because that is the name this
        wrapper has historically expected.
        NOTE(review): ilspycmd is believed to name single-file output
        "<assembly>.decompiled.cs" (and "<assembly>.il" for IL output) when
        -o is given - confirm against the installed ilspycmd version.

        Args:
            output_dir: Directory passed to ilspycmd via -o
            stems: File name stems to try, in priority order
            include_il: Also probe "<stem>.il" (IL output was requested)

        Returns:
            Candidate paths, highest priority first.
        """
        suffixes = [".cs", ".decompiled.cs"]
        if include_il:
            suffixes.append(".il")
        return [
            os.path.join(output_dir, stem + suffix) for stem in stems for suffix in suffixes
        ]

    @staticmethod
    def _read_first_existing(candidates: list[str]) -> str | None:
        """Return the UTF-8 text of the first candidate that is a regular file.

        Args:
            candidates: Paths to try, in priority order

        Returns:
            File contents, or None when no candidate exists.
        """
        for path in candidates:
            if os.path.isfile(path):
                with open(path, encoding="utf-8") as f:
                    return f.read()
        return None

    @classmethod
    def _scan_assembly_attributes(cls, source_text: str) -> dict[str, Any]:
        """Extract assembly-level attribute values from decompiled C# text.

        Args:
            source_text: Decompiled source (or ilspycmd stdout) to search

        Returns:
            Dict with keys "version" ("Unknown" when not found),
            "target_framework" (str or None), "is_signed" (bool) and
            "title" (str or None).
        """
        version = "Unknown"
        version_match = cls._ASSEMBLY_VERSION_PATTERN.search(source_text)
        if version_match:
            version = version_match.group(1)
        # Fall back to the file version only when no assembly version was found
        if version == "Unknown":
            file_version_match = cls._FILE_VERSION_PATTERN.search(source_text)
            if file_version_match:
                version = file_version_match.group(1)

        framework_match = cls._TARGET_FRAMEWORK_PATTERN.search(source_text)
        title_match = cls._ASSEMBLY_TITLE_PATTERN.search(source_text)

        # Heuristic: plain substring checks, so any occurrence of these
        # markers in the text (not only the assembly's own identity) counts.
        is_signed = (
            "[assembly: AssemblyKeyFile" in source_text
            or "[assembly: AssemblyDelaySign" in source_text
            or "PublicKeyToken=" in source_text
        )
        return {
            "version": version,
            "target_framework": framework_match.group(1) if framework_match else None,
            "is_signed": is_signed,
            "title": title_match.group(1) if title_match else None,
        }

    # ------------------------------------------------------------------
    # Subprocess plumbing
    # ------------------------------------------------------------------

    async def _run_command(
        self, args: list[str], input_data: str | None = None
    ) -> tuple[int, str, str]:
        """Run ilspycmd with given arguments.

        Args:
            args: Command line arguments
            input_data: Optional input data to pass to stdin

        Returns:
            Tuple of (return_code, stdout, stderr). return_code is -1 when
            the process could not be started or was killed after exceeding
            DECOMPILE_TIMEOUT_SECONDS; stderr then carries the reason.

        Note:
            Each stream is cut to MAX_OUTPUT_BYTES before decoding and a
            truncation marker is appended. The cut happens after
            communicate() has returned, so it limits what is decoded and
            handed to callers.
        """
        cmd = [self.ilspycmd_path] + args
        logger.debug("Running command: %s", " ".join(cmd))
        try:
            process = await asyncio.create_subprocess_exec(
                *cmd,
                stdin=asyncio.subprocess.PIPE if input_data else None,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
            )
            input_bytes = input_data.encode("utf-8") if input_data else None
            # Timeout to prevent hanging on malicious/corrupted assemblies
            try:
                stdout_bytes, stderr_bytes = await asyncio.wait_for(
                    process.communicate(input=input_bytes),
                    timeout=DECOMPILE_TIMEOUT_SECONDS,
                )
            except asyncio.TimeoutError:
                timeout_mins = DECOMPILE_TIMEOUT_SECONDS / 60
                logger.warning(
                    "Command timed out after %.0f minutes, killing process", timeout_mins
                )
                process.kill()
                await process.wait()  # Ensure process is cleaned up
                return (
                    -1,
                    "",
                    f"Command timed out after {timeout_mins:.0f} minutes. "
                    "The assembly may be corrupted or too complex.",
                )

            # The pre-truncation size is captured by the helper so the warning
            # reports the real amount of output that was produced.
            stdout, stdout_full_size = self._decode_capped(stdout_bytes, MAX_OUTPUT_BYTES)
            if stdout_full_size is not None:
                logger.warning(
                    "stdout truncated from %d to %d bytes", stdout_full_size, MAX_OUTPUT_BYTES
                )
                stdout += "\n\n[OUTPUT TRUNCATED - exceeded 50MB limit]"

            stderr, stderr_full_size = self._decode_capped(stderr_bytes, MAX_OUTPUT_BYTES)
            if stderr_full_size is not None:
                logger.warning(
                    "stderr truncated from %d to %d bytes", stderr_full_size, MAX_OUTPUT_BYTES
                )
                stderr += "\n\n[OUTPUT TRUNCATED - exceeded 50MB limit]"

            return process.returncode, stdout, stderr
        except OSError as e:  # includes FileNotFoundError for a missing executable
            logger.exception("Error running ilspycmd command: %s", e)
            return -1, "", str(e)

    # ------------------------------------------------------------------
    # Decompilation
    # ------------------------------------------------------------------

    async def decompile(self, request: DecompileRequest) -> DecompileResponse:
        """Decompile a .NET assembly.

        Args:
            request: Decompilation request

        Returns:
            Decompilation response. Failures (missing file, bad PE signature,
            ilspycmd error) are reported via success=False rather than raised.
        """
        if not os.path.exists(request.assembly_path):
            return DecompileResponse(
                success=False,
                error_message=f"Assembly file not found: {request.assembly_path}",
                assembly_name=os.path.basename(request.assembly_path),
            )

        # Validate PE signature before invoking ilspycmd
        is_valid, pe_error = _validate_pe_signature(request.assembly_path)
        if not is_valid:
            return DecompileResponse(
                success=False,
                error_message=pe_error,
                assembly_name=os.path.basename(request.assembly_path),
            )

        if request.output_dir:
            # User specified output directory - use it directly
            return await self._decompile_to_dir(request, request.output_dir)

        # No directory requested: work in a temporary one that is removed as
        # soon as the response has been built.
        with tempfile.TemporaryDirectory() as temp_dir:
            return await self._decompile_to_dir(request, temp_dir)

    async def _decompile_to_dir(
        self, request: DecompileRequest, output_dir: str
    ) -> DecompileResponse:
        """Internal helper to perform decompilation to a specific directory.

        Args:
            request: Decompilation request
            output_dir: Directory to write output to

        Returns:
            Decompilation response
        """
        args = [request.assembly_path]
        # Add language version
        args.extend(["-lv", request.language_version.value])
        # Add type filter if specified
        if request.type_name:
            args.extend(["-t", request.type_name])
        args.extend(["-o", output_dir])
        # Add project creation flag
        if request.create_project:
            args.append("-p")
        # Add IL code flag
        if request.show_il_code:
            args.append("-il")
        # Add reference paths
        for ref_path in request.reference_paths:
            args.extend(["-r", ref_path])
        # Add optimization flags
        if request.remove_dead_code:
            args.append("--no-dead-code")
        if request.remove_dead_stores:
            args.append("--no-dead-stores")
        # Add IL sequence points flag
        if request.show_il_sequence_points:
            args.append("--il-sequence-points")
        # Add directory structure flag
        if request.nested_directories:
            args.append("--nested-directories")
        # Disable update check for automation
        args.append("--disable-updatecheck")

        try:
            return_code, stdout, stderr = await self._run_command(args)
            assembly_name = os.path.splitext(os.path.basename(request.assembly_path))[0]

            if return_code != 0:
                return DecompileResponse(
                    success=False,
                    error_message=stderr or stdout or "Unknown error occurred",
                    assembly_name=assembly_name,
                    type_name=request.type_name,
                )

            # File name stems to probe, most specific first.
            stems: list[str] = []
            if request.type_name:
                # Single type decompilation
                stems.append(request.type_name.split(".")[-1])
                if not request.create_project:
                    stems.append(assembly_name)
            elif not request.create_project:
                # Single file decompilation
                stems.append(assembly_name)
            candidates = self._candidate_output_files(
                output_dir,
                stems,
                include_il=bool(request.show_il_code or request.show_il_sequence_points),
            )

            # Same truthiness test as decompile(): an empty-string output_dir
            # means "temporary directory", so output_dir must not be reported
            # back as output_path (it is deleted once we return).
            if not request.output_dir:
                output_path = None
                source_code = stdout
                if not stdout.strip():
                    # Nothing on stdout: the source presumably went to a file
                    # in the temporary directory instead.
                    file_text = self._read_first_existing(candidates)
                    if file_text is not None:
                        source_code = file_text
            else:
                output_path = output_dir
                # Try to read the main generated file if it exists
                source_code = self._read_first_existing(candidates)

            return DecompileResponse(
                success=True,
                source_code=source_code,
                output_path=output_path,
                assembly_name=assembly_name,
                type_name=request.type_name,
            )
        except OSError as e:
            logger.exception("Error during decompilation: %s", e)
            return DecompileResponse(
                success=False,
                error_message=str(e),
                assembly_name=os.path.basename(request.assembly_path),
                type_name=request.type_name,
            )

    # ------------------------------------------------------------------
    # Type listing
    # ------------------------------------------------------------------

    async def list_types(self, request: ListTypesRequest) -> ListTypesResponse:
        """List types in a .NET assembly.

        Args:
            request: List types request

        Returns:
            List types response
        """
        if not os.path.exists(request.assembly_path):
            return ListTypesResponse(
                success=False, error_message=f"Assembly file not found: {request.assembly_path}"
            )

        # Validate PE signature before invoking ilspycmd
        is_valid, pe_error = _validate_pe_signature(request.assembly_path)
        if not is_valid:
            return ListTypesResponse(success=False, error_message=pe_error)

        args = [request.assembly_path]
        # Add entity types to list: the individual values are concatenated
        # into one argument for -l
        entity_types_str = "".join(et.value for et in request.entity_types)
        args.extend(["-l", entity_types_str])
        # Add reference paths
        for ref_path in request.reference_paths:
            args.extend(["-r", ref_path])
        # Disable update check
        args.append("--disable-updatecheck")

        try:
            return_code, stdout, stderr = await self._run_command(args)
            if return_code == 0:
                types = self._parse_types_output(stdout)
                return ListTypesResponse(success=True, types=types, total_count=len(types))
            error_msg = stderr or stdout or "Unknown error occurred"
            return ListTypesResponse(success=False, error_message=error_msg)
        except OSError as e:
            logger.exception("Error listing types: %s", e)
            return ListTypesResponse(success=False, error_message=str(e))

    def _parse_types_output(self, output: str) -> list[TypeInfo]:
        """Parse the output from list types command.

        Args:
            output: Raw output from ilspycmd

        Returns:
            List of TypeInfo objects

        Note:
            ilspycmd outputs types in format "TypeKind: FullTypeName"
            Examples:
                Class: MyNamespace.MyClass
                Interface: MyNamespace.IService
                Struct: MyNamespace.MyStruct
                Class: MyNamespace.Outer+Nested (nested types)
        """
        types = []
        unparsed_count = 0
        # Split on "\n" only; stray "\r" is removed by the per-line strip().
        for line in output.strip().split("\n"):
            line = line.strip()
            if not line:
                continue
            match = self._TYPE_LINE_PATTERN.match(line)
            if match:
                kind = match.group(1)
                full_name = match.group(2).strip()
                # Extract namespace and name, handling nested types (+ separator)
                name, namespace = self._split_type_name(full_name)
                types.append(
                    TypeInfo(name=name, full_name=full_name, kind=kind, namespace=namespace)
                )
            else:
                # Log unexpected lines (but don't fail - ilspycmd may output warnings/info)
                unparsed_count += 1
                if unparsed_count <= MAX_UNPARSED_LOG_LINES:
                    logger.debug(
                        "Skipping unparsed line from ilspycmd: %s",
                        line[:UNPARSED_LINE_PREVIEW_LENGTH],
                    )
        if unparsed_count > MAX_UNPARSED_LOG_LINES:
            logger.debug("Skipped %d unparsed lines total", unparsed_count)
        return types

    @staticmethod
    def _split_type_name(full_name: str) -> tuple[str, str | None]:
        """Split a full type name into (name, namespace).

        Handles:
            - Simple types: "MyClass" -> ("MyClass", None)
            - Namespaced types: "MyNamespace.MyClass" -> ("MyClass", "MyNamespace")
            - Nested types: "MyNamespace.Outer+Nested" -> ("Outer+Nested", "MyNamespace")
            - Deep nesting: "NS.Outer+Mid+Inner" -> ("Outer+Mid+Inner", "NS")

        Args:
            full_name: Fully qualified type name

        Returns:
            Tuple of (type_name, namespace_or_none)
        """
        # The namespace separator is the last dot that occurs before the first
        # "+"; dots after a "+" belong to the nested part of the name.
        plus_idx = full_name.find("+")
        outer = full_name if plus_idx == -1 else full_name[:plus_idx]
        dot_idx = outer.rfind(".")
        if dot_idx == -1:
            # No namespace at all
            return full_name, None
        return full_name[dot_idx + 1:], full_name[:dot_idx]

    # ------------------------------------------------------------------
    # Diagrammer
    # ------------------------------------------------------------------

    async def generate_diagrammer(self, request: GenerateDiagrammerRequest) -> dict[str, Any]:
        """Generate HTML diagrammer for an assembly.

        Args:
            request: Generate diagrammer request

        Returns:
            Dictionary with "success" plus either "output_directory" and
            "message" (on success) or "error_message" (on failure).
        """
        if not os.path.exists(request.assembly_path):
            return {
                "success": False,
                "error_message": f"Assembly file not found: {request.assembly_path}",
            }

        # Validate PE signature before invoking ilspycmd
        is_valid, pe_error = _validate_pe_signature(request.assembly_path)
        if not is_valid:
            return {"success": False, "error_message": pe_error}

        args = [request.assembly_path, "--generate-diagrammer"]
        # Add output directory
        output_dir = request.output_dir
        if not output_dir:
            # Generate next to assembly
            assembly_dir = os.path.dirname(request.assembly_path)
            output_dir = os.path.join(assembly_dir, "diagrammer")
        args.extend(["-o", output_dir])
        # Add include/exclude patterns
        if request.include_pattern:
            args.extend(["--generate-diagrammer-include", request.include_pattern])
        if request.exclude_pattern:
            args.extend(["--generate-diagrammer-exclude", request.exclude_pattern])
        # Add documentation file
        if request.docs_path:
            args.extend(["--generate-diagrammer-docs", request.docs_path])
        # Add namespace stripping
        if request.strip_namespaces:
            args.extend(["--generate-diagrammer-strip-namespaces"] + request.strip_namespaces)
        # Add report excluded flag
        if request.report_excluded:
            args.append("--generate-diagrammer-report-excluded")
        # Disable update check
        args.append("--disable-updatecheck")

        try:
            return_code, stdout, stderr = await self._run_command(args)
            if return_code == 0:
                return {
                    "success": True,
                    "output_directory": output_dir,
                    "message": "HTML diagrammer generated successfully",
                }
            error_msg = stderr or stdout or "Unknown error occurred"
            return {"success": False, "error_message": error_msg}
        except OSError as e:
            logger.exception("Error generating diagrammer: %s", e)
            return {"success": False, "error_message": str(e)}

    # ------------------------------------------------------------------
    # Assembly info
    # ------------------------------------------------------------------

    async def get_assembly_info(self, request: AssemblyInfoRequest) -> AssemblyInfo:
        """Get detailed information about an assembly by decompiling assembly attributes.

        Args:
            request: Assembly info request

        Returns:
            Assembly information including version, target framework, etc.
            Fields that cannot be determined keep their defaults
            (version "Unknown", target_framework None, is_signed False).

        Raises:
            FileNotFoundError: If the assembly path does not exist.
            ValueError: If the file does not start with a PE (MZ) signature.
        """
        if not os.path.exists(request.assembly_path):
            raise FileNotFoundError(f"Assembly file not found: {request.assembly_path}")

        # Validate PE signature before invoking ilspycmd
        is_valid, pe_error = _validate_pe_signature(request.assembly_path)
        if not is_valid:
            raise ValueError(pe_error)

        assembly_path = Path(request.assembly_path)
        name = assembly_path.stem

        # Defaults used when decompilation fails or an attribute is absent
        attrs: dict[str, Any] = {
            "version": "Unknown",
            "target_framework": None,
            "is_signed": False,
            "title": None,
        }

        # A single decompile run is enough: the assembly-level attributes are
        # read from its output. Use TemporaryDirectory for guaranteed cleanup.
        with tempfile.TemporaryDirectory() as temp_dir:
            args = [
                request.assembly_path,
                "-o",
                temp_dir,
                "-lv",
                "Latest",
                "--disable-updatecheck",
            ]
            return_code, stdout, _stderr = await self._run_command(args)
            if return_code == 0:
                # Prefer the main decompiled file; fall back to stdout
                file_text = self._read_first_existing(
                    self._candidate_output_files(temp_dir, [name])
                )
                attrs = self._scan_assembly_attributes(
                    stdout if file_text is None else file_text
                )

        version = attrs["version"]
        # Full name defaults to the file name unless an AssemblyTitle was found
        full_name = assembly_path.name
        if attrs["title"] is not None:
            full_name = f"{attrs['title']}, Version={version}"

        return AssemblyInfo(
            name=name,
            version=version,
            full_name=full_name,
            location=str(assembly_path.absolute()),
            target_framework=attrs["target_framework"],
            # Not derived from the decompiled output; always None here
            runtime_version=None,
            is_signed=attrs["is_signed"],
            has_debug_info=os.path.exists(assembly_path.with_suffix(".pdb")),
        )