security: path validation, temp cleanup, output limits (S1-S4)

This commit is contained in:
Ryan Malloy 2026-02-08 11:31:00 -07:00
commit 3b6afd0646
5 changed files with 521 additions and 47 deletions

View File

@ -2,10 +2,10 @@
"project": "mcilspy-code-review-fixes", "project": "mcilspy-code-review-fixes",
"created": "2025-02-08T00:00:00Z", "created": "2025-02-08T00:00:00Z",
"domains": { "domains": {
"security": { "status": "pending", "branch": "fix/security", "priority": 1 }, "security": { "status": "ready", "branch": "fix/security", "priority": 1 },
"architecture": { "status": "pending", "branch": "fix/architecture", "priority": 2 }, "architecture": { "status": "ready", "branch": "fix/architecture", "priority": 2 },
"performance": { "status": "pending", "branch": "fix/performance", "priority": 3 }, "performance": { "status": "ready", "branch": "fix/performance", "priority": 3 },
"testing": { "status": "pending", "branch": "fix/testing", "priority": 4 } "testing": { "status": "ready", "branch": "fix/testing", "priority": 4 }
}, },
"merge_order": ["security", "architecture", "performance", "testing"] "merge_order": ["security", "architecture", "performance", "testing"]
} }

View File

@ -22,6 +22,10 @@ from .models import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Maximum bytes to read from subprocess stdout/stderr to prevent memory exhaustion
# from malicious or corrupted assemblies that produce huge output.
# stdout and stderr are each clamped to this limit independently; a truncated
# stream gets an explicit "[OUTPUT TRUNCATED ...]" marker appended to it.
MAX_OUTPUT_BYTES = 50_000_000  # 50 MB
class ILSpyWrapper: class ILSpyWrapper:
"""Wrapper class for ILSpy command line tool.""" """Wrapper class for ILSpy command line tool."""
@ -85,6 +89,10 @@ class ILSpyWrapper:
Returns: Returns:
Tuple of (return_code, stdout, stderr) Tuple of (return_code, stdout, stderr)
Note:
Output is truncated to MAX_OUTPUT_BYTES to prevent memory exhaustion
from malicious or corrupted assemblies.
""" """
cmd = [self.ilspycmd_path] + args cmd = [self.ilspycmd_path] + args
logger.debug(f"Running command: {' '.join(cmd)}") logger.debug(f"Running command: {' '.join(cmd)}")
@ -111,9 +119,33 @@ class ILSpyWrapper:
await process.wait() # Ensure process is cleaned up await process.wait() # Ensure process is cleaned up
return -1, "", "Command timed out after 5 minutes. The assembly may be corrupted or too complex." return -1, "", "Command timed out after 5 minutes. The assembly may be corrupted or too complex."
# Truncate output if it exceeds the limit to prevent memory exhaustion.
# BUG FIX: the original computed len(...) AFTER slicing, so the warning
# always reported "truncated from MAX_OUTPUT_BYTES to MAX_OUTPUT_BYTES".
# Capture the real pre-truncation size before clamping.
stdout_truncated = False
stderr_truncated = False
if stdout_bytes and len(stdout_bytes) > MAX_OUTPUT_BYTES:
    original_stdout_len = len(stdout_bytes)
    stdout_bytes = stdout_bytes[:MAX_OUTPUT_BYTES]
    stdout_truncated = True
    logger.warning(
        f"stdout truncated from {original_stdout_len} to {MAX_OUTPUT_BYTES} bytes"
    )
if stderr_bytes and len(stderr_bytes) > MAX_OUTPUT_BYTES:
    original_stderr_len = len(stderr_bytes)
    stderr_bytes = stderr_bytes[:MAX_OUTPUT_BYTES]
    stderr_truncated = True
    logger.warning(
        f"stderr truncated from {original_stderr_len} to {MAX_OUTPUT_BYTES} bytes"
    )
stdout = stdout_bytes.decode("utf-8", errors="replace") if stdout_bytes else "" stdout = stdout_bytes.decode("utf-8", errors="replace") if stdout_bytes else ""
stderr = stderr_bytes.decode("utf-8", errors="replace") if stderr_bytes else "" stderr = stderr_bytes.decode("utf-8", errors="replace") if stderr_bytes else ""
# Add truncation warning to output
if stdout_truncated:
stdout += "\n\n[OUTPUT TRUNCATED - exceeded 50MB limit]"
if stderr_truncated:
stderr += "\n\n[OUTPUT TRUNCATED - exceeded 50MB limit]"
return process.returncode, stdout, stderr return process.returncode, stdout, stderr
except Exception as e: except Exception as e:
@ -136,6 +168,28 @@ class ILSpyWrapper:
assembly_name=os.path.basename(request.assembly_path), assembly_name=os.path.basename(request.assembly_path),
) )
# Use TemporaryDirectory context manager for guaranteed cleanup (no race condition)
# when user doesn't specify an output directory
if request.output_dir:
# User specified output directory - use it directly
return await self._decompile_to_dir(request, request.output_dir)
else:
# Create a temporary directory with guaranteed cleanup
with tempfile.TemporaryDirectory() as temp_dir:
return await self._decompile_to_dir(request, temp_dir)
async def _decompile_to_dir(
self, request: DecompileRequest, output_dir: str
) -> DecompileResponse:
"""Internal helper to perform decompilation to a specific directory.
Args:
request: Decompilation request
output_dir: Directory to write output to
Returns:
Decompilation response
"""
args = [request.assembly_path] args = [request.assembly_path]
# Add language version # Add language version
@ -145,13 +199,6 @@ class ILSpyWrapper:
if request.type_name: if request.type_name:
args.extend(["-t", request.type_name]) args.extend(["-t", request.type_name])
# Add output directory if specified
temp_dir = None
output_dir = request.output_dir
if not output_dir:
temp_dir = tempfile.mkdtemp()
output_dir = temp_dir
args.extend(["-o", output_dir]) args.extend(["-o", output_dir])
# Add project creation flag # Add project creation flag
@ -190,7 +237,7 @@ class ILSpyWrapper:
assembly_name = os.path.splitext(os.path.basename(request.assembly_path))[0] assembly_name = os.path.splitext(os.path.basename(request.assembly_path))[0]
if return_code == 0: if return_code == 0:
# If no output directory was specified, return stdout as source code # If no output directory was specified by user, return stdout as source code
source_code = None source_code = None
output_path = None output_path = None
@ -237,10 +284,6 @@ class ILSpyWrapper:
assembly_name=os.path.basename(request.assembly_path), assembly_name=os.path.basename(request.assembly_path),
type_name=request.type_name, type_name=request.type_name,
) )
finally:
# Clean up temporary directory
if temp_dir and os.path.exists(temp_dir):
shutil.rmtree(temp_dir, ignore_errors=True)
async def list_types(self, request: ListTypesRequest) -> ListTypesResponse: async def list_types(self, request: ListTypesRequest) -> ListTypesResponse:
"""List types in a .NET assembly. """List types in a .NET assembly.
@ -464,8 +507,8 @@ class ILSpyWrapper:
# Try to extract more info by decompiling assembly attributes # Try to extract more info by decompiling assembly attributes
# Decompile with minimal output to get assembly-level attributes # Decompile with minimal output to get assembly-level attributes
temp_dir = tempfile.mkdtemp() # Use TemporaryDirectory context manager for guaranteed cleanup (no race condition)
try: with tempfile.TemporaryDirectory() as temp_dir:
args = [ args = [
request.assembly_path, request.assembly_path,
"-o", "-o",
@ -523,10 +566,6 @@ class ILSpyWrapper:
if title_match: if title_match:
full_name = f"{title_match.group(1)}, Version={version}" full_name = f"{title_match.group(1)}, Version={version}"
finally:
# Clean up temp directory
shutil.rmtree(temp_dir, ignore_errors=True)
return AssemblyInfo( return AssemblyInfo(
name=name, name=name,
version=version, version=version,

View File

@ -18,6 +18,10 @@ from dnfile.mdtable import TypeDefRow
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Maximum assembly file size to load (in megabytes).
# Prevents memory exhaustion from extremely large or malicious assemblies.
# Enforced in MetadataReader.__init__ before the file is parsed; exceeding
# the limit raises AssemblySizeError (a ValueError subclass).
MAX_ASSEMBLY_SIZE_MB = 500
@dataclass @dataclass
class MethodInfo: class MethodInfo:
@ -101,6 +105,12 @@ class AssemblyMetadata:
referenced_assemblies: list[str] = field(default_factory=list) referenced_assemblies: list[str] = field(default_factory=list)
class AssemblySizeError(ValueError):
    """Raised when an assembly exceeds the maximum allowed size."""
class MetadataReader: class MetadataReader:
"""Read .NET assembly metadata directly using dnfile.""" """Read .NET assembly metadata directly using dnfile."""
@ -109,11 +119,25 @@ class MetadataReader:
Args: Args:
assembly_path: Path to the .NET assembly file assembly_path: Path to the .NET assembly file
Raises:
FileNotFoundError: If the assembly file doesn't exist
AssemblySizeError: If the assembly exceeds MAX_ASSEMBLY_SIZE_MB
""" """
self.assembly_path = Path(assembly_path) self.assembly_path = Path(assembly_path)
if not self.assembly_path.exists(): if not self.assembly_path.exists():
raise FileNotFoundError(f"Assembly not found: {assembly_path}") raise FileNotFoundError(f"Assembly not found: {assembly_path}")
# Check file size before loading to prevent memory exhaustion
file_size_bytes = self.assembly_path.stat().st_size
max_size_bytes = MAX_ASSEMBLY_SIZE_MB * 1024 * 1024
if file_size_bytes > max_size_bytes:
size_mb = file_size_bytes / (1024 * 1024)
raise AssemblySizeError(
f"Assembly file size ({size_mb:.1f} MB) exceeds maximum allowed "
f"({MAX_ASSEMBLY_SIZE_MB} MB). This limit prevents memory exhaustion."
)
self._pe: dnfile.dnPE | None = None self._pe: dnfile.dnPE | None = None
self._type_cache: dict[int, TypeDefRow] = {} self._type_cache: dict[int, TypeDefRow] = {}

View File

@ -96,6 +96,57 @@ def _format_error(error: Exception, context: str = "") -> str:
return f"**Error**: {error_msg}" return f"**Error**: {error_msg}"
class AssemblyPathError(ValueError):
"""Raised when an assembly path fails validation."""
pass
def _validate_assembly_path(assembly_path: str) -> str:
"""Validate and normalize an assembly path for security.
Performs the following checks:
1. Path is not empty
2. Resolves to an absolute path (prevents path traversal)
3. File exists and is a regular file (not a directory or symlink to directory)
4. Has a valid .NET assembly extension (.dll or .exe)
Args:
assembly_path: User-provided path to a .NET assembly
Returns:
Absolute, validated path to the assembly
Raises:
AssemblyPathError: If the path fails any validation check
"""
if not assembly_path or not assembly_path.strip():
raise AssemblyPathError("Assembly path cannot be empty")
# Resolve to absolute path (handles .., symlinks, etc.)
try:
resolved_path = os.path.realpath(os.path.expanduser(assembly_path.strip()))
except (OSError, ValueError) as e:
raise AssemblyPathError(f"Invalid path: {e}") from e
# Check if path exists
if not os.path.exists(resolved_path):
raise AssemblyPathError(f"Assembly file not found: {resolved_path}")
# Check if it's a regular file (not a directory)
if not os.path.isfile(resolved_path):
raise AssemblyPathError(f"Path is not a file: {resolved_path}")
# Validate extension
_, ext = os.path.splitext(resolved_path)
if ext.lower() not in (".dll", ".exe"):
raise AssemblyPathError(
f"Invalid assembly extension '{ext}'. Expected .dll or .exe"
)
return resolved_path
def _find_ilspycmd_path() -> str | None: def _find_ilspycmd_path() -> str | None:
"""Find ilspycmd in PATH or common install locations.""" """Find ilspycmd in PATH or common install locations."""
# Check PATH first # Check PATH first
@ -562,8 +613,14 @@ async def decompile_assembly(
show_il_sequence_points: Include debugging sequence points in IL output (implies show_il_code) show_il_sequence_points: Include debugging sequence points in IL output (implies show_il_code)
nested_directories: Organize output files in namespace-based directory hierarchy nested_directories: Organize output files in namespace-based directory hierarchy
""" """
# Validate assembly path before any processing
try:
validated_path = _validate_assembly_path(assembly_path)
except AssemblyPathError as e:
return _format_error(e, "path validation")
if ctx: if ctx:
await ctx.info(f"Starting decompilation of assembly: {assembly_path}") await ctx.info(f"Starting decompilation of assembly: {validated_path}")
try: try:
wrapper = get_wrapper(ctx) wrapper = get_wrapper(ctx)
@ -572,7 +629,7 @@ async def decompile_assembly(
from .models import DecompileRequest from .models import DecompileRequest
request = DecompileRequest( request = DecompileRequest(
assembly_path=assembly_path, assembly_path=validated_path,
output_dir=output_dir, output_dir=output_dir,
type_name=type_name, type_name=type_name,
language_version=LanguageVersion(language_version), language_version=LanguageVersion(language_version),
@ -633,8 +690,14 @@ async def list_types(
- "enum" or "e" - "enum" or "e"
Example: ["class", "interface"] or ["c", "i"] Example: ["class", "interface"] or ["c", "i"]
""" """
# Validate assembly path before any processing
try:
validated_path = _validate_assembly_path(assembly_path)
except AssemblyPathError as e:
return _format_error(e, "path validation")
if ctx: if ctx:
await ctx.info(f"Listing types in assembly: {assembly_path}") await ctx.info(f"Listing types in assembly: {validated_path}")
try: try:
wrapper = get_wrapper(ctx) wrapper = get_wrapper(ctx)
@ -654,12 +717,12 @@ async def list_types(
from .models import ListTypesRequest from .models import ListTypesRequest
request = ListTypesRequest(assembly_path=assembly_path, entity_types=entity_type_enums) request = ListTypesRequest(assembly_path=validated_path, entity_types=entity_type_enums)
response = await wrapper.list_types(request) response = await wrapper.list_types(request)
if response.success and response.types: if response.success and response.types:
content = f"# Types in {assembly_path}\n\n" content = f"# Types in {validated_path}\n\n"
content += f"Found {response.total_count} types:\n\n" content += f"Found {response.total_count} types:\n\n"
# Group by namespace # Group by namespace
@ -713,8 +776,14 @@ async def generate_diagrammer(
include_pattern: Regex to whitelist types (e.g., "MyApp\\\\.Services\\\\..+" for Services namespace) include_pattern: Regex to whitelist types (e.g., "MyApp\\\\.Services\\\\..+" for Services namespace)
exclude_pattern: Regex to blacklist types (e.g., ".*Generated.*" to hide generated code) exclude_pattern: Regex to blacklist types (e.g., ".*Generated.*" to hide generated code)
""" """
# Validate assembly path before any processing
try:
validated_path = _validate_assembly_path(assembly_path)
except AssemblyPathError as e:
return _format_error(e, "path validation")
if ctx: if ctx:
await ctx.info(f"Generating assembly diagram: {assembly_path}") await ctx.info(f"Generating assembly diagram: {validated_path}")
try: try:
wrapper = get_wrapper(ctx) wrapper = get_wrapper(ctx)
@ -722,7 +791,7 @@ async def generate_diagrammer(
from .models import GenerateDiagrammerRequest from .models import GenerateDiagrammerRequest
request = GenerateDiagrammerRequest( request = GenerateDiagrammerRequest(
assembly_path=assembly_path, assembly_path=validated_path,
output_dir=output_dir, output_dir=output_dir,
include_pattern=include_pattern, include_pattern=include_pattern,
exclude_pattern=exclude_pattern, exclude_pattern=exclude_pattern,
@ -756,15 +825,21 @@ async def get_assembly_info(assembly_path: str, ctx: Context | None = None) -> s
Args: Args:
assembly_path: Full path to the .NET assembly file (.dll or .exe) assembly_path: Full path to the .NET assembly file (.dll or .exe)
""" """
# Validate assembly path before any processing
try:
validated_path = _validate_assembly_path(assembly_path)
except AssemblyPathError as e:
return _format_error(e, "path validation")
if ctx: if ctx:
await ctx.info(f"Getting assembly info: {assembly_path}") await ctx.info(f"Getting assembly info: {validated_path}")
try: try:
wrapper = get_wrapper(ctx) wrapper = get_wrapper(ctx)
from .models import AssemblyInfoRequest from .models import AssemblyInfoRequest
request = AssemblyInfoRequest(assembly_path=assembly_path) request = AssemblyInfoRequest(assembly_path=validated_path)
info = await wrapper.get_assembly_info(request) info = await wrapper.get_assembly_info(request)
@ -817,8 +892,14 @@ async def search_types(
case_sensitive: Whether pattern matching is case-sensitive (default: False) case_sensitive: Whether pattern matching is case-sensitive (default: False)
use_regex: Treat pattern as regular expression (default: False) use_regex: Treat pattern as regular expression (default: False)
""" """
# Validate assembly path before any processing
try:
validated_path = _validate_assembly_path(assembly_path)
except AssemblyPathError as e:
return _format_error(e, "path validation")
if ctx: if ctx:
await ctx.info(f"Searching for types matching '{pattern}' in: {assembly_path}") await ctx.info(f"Searching for types matching '{pattern}' in: {validated_path}")
try: try:
wrapper = get_wrapper(ctx) wrapper = get_wrapper(ctx)
@ -838,7 +919,7 @@ async def search_types(
from .models import ListTypesRequest from .models import ListTypesRequest
request = ListTypesRequest(assembly_path=assembly_path, entity_types=entity_type_enums) request = ListTypesRequest(assembly_path=validated_path, entity_types=entity_type_enums)
response = await wrapper.list_types(request) response = await wrapper.list_types(request)
if not response.success: if not response.success:
@ -936,8 +1017,14 @@ async def search_strings(
use_regex: Treat pattern as regular expression (default: False) use_regex: Treat pattern as regular expression (default: False)
max_results: Maximum number of matches to return (default: 100) max_results: Maximum number of matches to return (default: 100)
""" """
# Validate assembly path before any processing
try:
validated_path = _validate_assembly_path(assembly_path)
except AssemblyPathError as e:
return _format_error(e, "path validation")
if ctx: if ctx:
await ctx.info(f"Searching for strings matching '{pattern}' in: {assembly_path}") await ctx.info(f"Searching for strings matching '{pattern}' in: {validated_path}")
try: try:
wrapper = get_wrapper(ctx) wrapper = get_wrapper(ctx)
@ -946,7 +1033,7 @@ async def search_strings(
from .models import DecompileRequest from .models import DecompileRequest
request = DecompileRequest( request = DecompileRequest(
assembly_path=assembly_path, assembly_path=validated_path,
show_il_code=True, # IL makes string literals explicit show_il_code=True, # IL makes string literals explicit
language_version=LanguageVersion.LATEST, language_version=LanguageVersion.LATEST,
) )
@ -1081,13 +1168,19 @@ async def search_methods(
case_sensitive: Whether pattern matching is case-sensitive (default: False) case_sensitive: Whether pattern matching is case-sensitive (default: False)
use_regex: Treat pattern as regular expression (default: False) use_regex: Treat pattern as regular expression (default: False)
""" """
# Validate assembly path before any processing
try:
validated_path = _validate_assembly_path(assembly_path)
except AssemblyPathError as e:
return _format_error(e, "path validation")
if ctx: if ctx:
await ctx.info(f"Searching for methods matching '{pattern}' in: {assembly_path}") await ctx.info(f"Searching for methods matching '{pattern}' in: {validated_path}")
try: try:
from .metadata_reader import MetadataReader from .metadata_reader import MetadataReader
with MetadataReader(assembly_path) as reader: with MetadataReader(validated_path) as reader:
methods = reader.list_methods( methods = reader.list_methods(
type_filter=type_filter, type_filter=type_filter,
namespace_filter=namespace_filter, namespace_filter=namespace_filter,
@ -1198,13 +1291,19 @@ async def search_fields(
case_sensitive: Whether pattern matching is case-sensitive (default: False) case_sensitive: Whether pattern matching is case-sensitive (default: False)
use_regex: Treat pattern as regular expression (default: False) use_regex: Treat pattern as regular expression (default: False)
""" """
# Validate assembly path before any processing
try:
validated_path = _validate_assembly_path(assembly_path)
except AssemblyPathError as e:
return _format_error(e, "path validation")
if ctx: if ctx:
await ctx.info(f"Searching for fields matching '{pattern}' in: {assembly_path}") await ctx.info(f"Searching for fields matching '{pattern}' in: {validated_path}")
try: try:
from .metadata_reader import MetadataReader from .metadata_reader import MetadataReader
with MetadataReader(assembly_path) as reader: with MetadataReader(validated_path) as reader:
fields = reader.list_fields( fields = reader.list_fields(
type_filter=type_filter, type_filter=type_filter,
namespace_filter=namespace_filter, namespace_filter=namespace_filter,
@ -1307,13 +1406,19 @@ async def search_properties(
case_sensitive: Whether pattern matching is case-sensitive (default: False) case_sensitive: Whether pattern matching is case-sensitive (default: False)
use_regex: Treat pattern as regular expression (default: False) use_regex: Treat pattern as regular expression (default: False)
""" """
# Validate assembly path before any processing
try:
validated_path = _validate_assembly_path(assembly_path)
except AssemblyPathError as e:
return _format_error(e, "path validation")
if ctx: if ctx:
await ctx.info(f"Searching for properties matching '{pattern}' in: {assembly_path}") await ctx.info(f"Searching for properties matching '{pattern}' in: {validated_path}")
try: try:
from .metadata_reader import MetadataReader from .metadata_reader import MetadataReader
with MetadataReader(assembly_path) as reader: with MetadataReader(validated_path) as reader:
properties = reader.list_properties( properties = reader.list_properties(
type_filter=type_filter, type_filter=type_filter,
namespace_filter=namespace_filter, namespace_filter=namespace_filter,
@ -1398,13 +1503,19 @@ async def list_events(
type_filter: Only list events in types containing this string type_filter: Only list events in types containing this string
namespace_filter: Only list events in namespaces containing this string namespace_filter: Only list events in namespaces containing this string
""" """
# Validate assembly path before any processing
try:
validated_path = _validate_assembly_path(assembly_path)
except AssemblyPathError as e:
return _format_error(e, "path validation")
if ctx: if ctx:
await ctx.info(f"Listing events in: {assembly_path}") await ctx.info(f"Listing events in: {validated_path}")
try: try:
from .metadata_reader import MetadataReader from .metadata_reader import MetadataReader
with MetadataReader(assembly_path) as reader: with MetadataReader(validated_path) as reader:
events = reader.list_events( events = reader.list_events(
type_filter=type_filter, type_filter=type_filter,
namespace_filter=namespace_filter, namespace_filter=namespace_filter,
@ -1455,13 +1566,19 @@ async def list_resources(
Args: Args:
assembly_path: Full path to the .NET assembly file (.dll or .exe) assembly_path: Full path to the .NET assembly file (.dll or .exe)
""" """
# Validate assembly path before any processing
try:
validated_path = _validate_assembly_path(assembly_path)
except AssemblyPathError as e:
return _format_error(e, "path validation")
if ctx: if ctx:
await ctx.info(f"Listing resources in: {assembly_path}") await ctx.info(f"Listing resources in: {validated_path}")
try: try:
from .metadata_reader import MetadataReader from .metadata_reader import MetadataReader
with MetadataReader(assembly_path) as reader: with MetadataReader(validated_path) as reader:
resources = reader.list_resources() resources = reader.list_resources()
if not resources: if not resources:
@ -1502,13 +1619,19 @@ async def get_metadata_summary(
Args: Args:
assembly_path: Full path to the .NET assembly file (.dll or .exe) assembly_path: Full path to the .NET assembly file (.dll or .exe)
""" """
# Validate assembly path before any processing
try:
validated_path = _validate_assembly_path(assembly_path)
except AssemblyPathError as e:
return _format_error(e, "path validation")
if ctx: if ctx:
await ctx.info(f"Getting metadata summary: {assembly_path}") await ctx.info(f"Getting metadata summary: {validated_path}")
try: try:
from .metadata_reader import MetadataReader from .metadata_reader import MetadataReader
with MetadataReader(assembly_path) as reader: with MetadataReader(validated_path) as reader:
meta = reader.get_assembly_metadata() meta = reader.get_assembly_metadata()
content = "# Assembly Metadata Summary\n\n" content = "# Assembly Metadata Summary\n\n"

288
tests/test_security.py Normal file
View File

@ -0,0 +1,288 @@
"""Security-focused tests for mcilspy validation functions.
These tests verify the security hardening in S1-S4:
- S1: Path traversal prevention via _validate_assembly_path()
- S2: Temp directory race condition fix (structural - uses TemporaryDirectory)
- S3: Subprocess output size limits (MAX_OUTPUT_BYTES)
- S4: Assembly file size limits (MAX_ASSEMBLY_SIZE_MB)
"""
import os
import tempfile
from pathlib import Path
from unittest.mock import patch
import pytest
class TestValidateAssemblyPath:
    """Tests for S1: _validate_assembly_path() security validation."""

    def test_empty_path_rejected(self):
        """Empty paths should be rejected."""
        from mcilspy.server import AssemblyPathError, _validate_assembly_path

        with pytest.raises(AssemblyPathError, match="cannot be empty"):
            _validate_assembly_path("")

        with pytest.raises(AssemblyPathError, match="cannot be empty"):
            _validate_assembly_path(" ")

    def test_nonexistent_file_rejected(self):
        """Non-existent files should be rejected."""
        from mcilspy.server import AssemblyPathError, _validate_assembly_path

        with pytest.raises(AssemblyPathError, match="not found"):
            _validate_assembly_path("/nonexistent/path/to/assembly.dll")

    def test_directory_rejected(self):
        """Directories should be rejected (file required)."""
        from mcilspy.server import AssemblyPathError, _validate_assembly_path

        # Fixed: use the platform temp directory instead of hard-coding
        # "/tmp", which does not exist on Windows.
        with pytest.raises(AssemblyPathError, match="not a file"):
            _validate_assembly_path(tempfile.gettempdir())

    def test_invalid_extension_rejected(self):
        """Files without .dll or .exe extension should be rejected."""
        from mcilspy.server import AssemblyPathError, _validate_assembly_path

        # Create a temp file with wrong extension
        with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as f:
            temp_path = f.name
        try:
            with pytest.raises(AssemblyPathError, match="Invalid assembly extension"):
                _validate_assembly_path(temp_path)
        finally:
            os.unlink(temp_path)

    def test_valid_dll_accepted(self):
        """Valid .dll files should be accepted."""
        from mcilspy.server import _validate_assembly_path

        # Create a temp file with .dll extension
        with tempfile.NamedTemporaryFile(suffix=".dll", delete=False) as f:
            temp_path = f.name
        try:
            result = _validate_assembly_path(temp_path)
            # Should return absolute path
            assert os.path.isabs(result)
            assert result.endswith(".dll")
        finally:
            os.unlink(temp_path)

    def test_valid_exe_accepted(self):
        """Valid .exe files should be accepted."""
        from mcilspy.server import _validate_assembly_path

        # Create a temp file with .exe extension
        with tempfile.NamedTemporaryFile(suffix=".exe", delete=False) as f:
            temp_path = f.name
        try:
            result = _validate_assembly_path(temp_path)
            # Should return absolute path
            assert os.path.isabs(result)
            assert result.endswith(".exe")
        finally:
            os.unlink(temp_path)

    def test_case_insensitive_extension(self):
        """Extension check should be case-insensitive."""
        from mcilspy.server import _validate_assembly_path

        # Create temp files with mixed case extensions
        for ext in [".DLL", ".Dll", ".EXE", ".Exe"]:
            with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as f:
                temp_path = f.name
            try:
                result = _validate_assembly_path(temp_path)
                assert os.path.isabs(result)
            finally:
                os.unlink(temp_path)

    def test_path_traversal_resolved(self):
        """Path traversal attempts should be resolved to absolute paths."""
        from mcilspy.server import _validate_assembly_path

        # Create a temp file
        with tempfile.NamedTemporaryFile(suffix=".dll", delete=False) as f:
            temp_path = f.name
        try:
            # Create a path with .. components
            parent = os.path.dirname(temp_path)
            filename = os.path.basename(temp_path)
            traversal_path = os.path.join(parent, "..", os.path.basename(parent), filename)

            result = _validate_assembly_path(traversal_path)
            # Should resolve to absolute path without ..
            assert ".." not in result
            assert os.path.isabs(result)
        finally:
            os.unlink(temp_path)

    def test_tilde_expansion(self):
        """Home directory tilde should be expanded."""
        from mcilspy.server import AssemblyPathError, _validate_assembly_path

        # This should fail because the file doesn't exist, but the path should be expanded
        with pytest.raises(AssemblyPathError, match="not found"):
            _validate_assembly_path("~/nonexistent.dll")
class TestMaxOutputBytes:
    """Tests for S3: MAX_OUTPUT_BYTES constant and truncation."""

    def test_constant_defined(self):
        """MAX_OUTPUT_BYTES constant should be defined."""
        from mcilspy.ilspy_wrapper import MAX_OUTPUT_BYTES

        assert MAX_OUTPUT_BYTES > 0
        assert MAX_OUTPUT_BYTES == 50_000_000  # 50 MB

    def test_constant_reasonable_size(self):
        """MAX_OUTPUT_BYTES should be a reasonable size (not too small, not too large)."""
        from mcilspy.ilspy_wrapper import MAX_OUTPUT_BYTES

        # At least 1 MB (useful output), at most 100 MB (memory-safe).
        assert 1_000_000 <= MAX_OUTPUT_BYTES <= 100_000_000
class TestMaxAssemblySize:
    """Tests for S4: MAX_ASSEMBLY_SIZE_MB constant and size check."""

    def test_constant_defined(self):
        """MAX_ASSEMBLY_SIZE_MB constant should be defined."""
        from mcilspy.metadata_reader import MAX_ASSEMBLY_SIZE_MB

        assert MAX_ASSEMBLY_SIZE_MB == 500  # 500 MB
        assert MAX_ASSEMBLY_SIZE_MB > 0

    def test_assembly_size_error_defined(self):
        """AssemblySizeError exception should be defined."""
        from mcilspy.metadata_reader import AssemblySizeError

        assert issubclass(AssemblySizeError, ValueError)

    def test_oversized_file_rejected(self):
        """Files exceeding MAX_ASSEMBLY_SIZE_MB should be rejected."""
        from mcilspy.metadata_reader import (
            AssemblySizeError,
            MAX_ASSEMBLY_SIZE_MB,
            MetadataReader,
        )

        # Create a temp file
        with tempfile.NamedTemporaryFile(suffix=".dll", delete=False) as f:
            temp_path = f.name
        try:
            # Mock the file size to be just over the limit.
            # (Fixed: removed an unused `os.stat(temp_path)` capture.)
            oversized_bytes = (MAX_ASSEMBLY_SIZE_MB + 1) * 1024 * 1024
            with patch.object(Path, "stat") as mock_stat:
                mock_stat.return_value = type(
                    "StatResult",
                    (),
                    {"st_size": oversized_bytes},
                )()
                with pytest.raises(AssemblySizeError, match="exceeds maximum"):
                    MetadataReader(temp_path)
        finally:
            os.unlink(temp_path)

    def test_normal_sized_file_accepted(self):
        """Files under MAX_ASSEMBLY_SIZE_MB should be accepted (at init)."""
        from mcilspy.metadata_reader import MetadataReader

        # Create a small temp file
        with tempfile.NamedTemporaryFile(suffix=".dll", delete=False) as f:
            f.write(b"small content")
            temp_path = f.name
        try:
            # Should not raise on init (will fail later when trying to parse)
            reader = MetadataReader(temp_path)
            assert reader.assembly_path.exists()
        finally:
            os.unlink(temp_path)
class TestTemporaryDirectoryUsage:
    """Tests for S2: TemporaryDirectory context manager usage.

    These are structural tests that verify the code uses the secure pattern.
    """

    def test_decompile_uses_temp_directory_context(self):
        """Verify decompile method structure uses TemporaryDirectory."""
        import inspect

        from mcilspy.ilspy_wrapper import ILSpyWrapper

        # Get source code of decompile method
        source = inspect.getsource(ILSpyWrapper.decompile)

        # Should use TemporaryDirectory as a context manager.
        # (Fixed: dropped a redundant bare "tempfile.TemporaryDirectory()"
        # substring assert that this stricter check already implies.)
        assert "with tempfile.TemporaryDirectory()" in source
        # Should NOT use the old mkdtemp pattern
        assert "tempfile.mkdtemp()" not in source

    def test_get_assembly_info_uses_temp_directory_context(self):
        """Verify get_assembly_info method structure uses TemporaryDirectory."""
        import inspect

        from mcilspy.ilspy_wrapper import ILSpyWrapper

        # Get source code of get_assembly_info method
        source = inspect.getsource(ILSpyWrapper.get_assembly_info)

        # Should use TemporaryDirectory context manager
        assert "with tempfile.TemporaryDirectory()" in source
        # Should NOT use the old mkdtemp pattern
        assert "tempfile.mkdtemp()" not in source
class TestPathValidationIntegration:
    """Integration tests to verify path validation is applied to all tools."""

    def test_all_tools_have_path_validation(self):
        """Verify all assembly-accepting tools call _validate_assembly_path."""
        import inspect

        from mcilspy import server

        # Every tool function that accepts an assembly_path argument.
        assembly_tools = (
            "decompile_assembly",
            "list_types",
            "generate_diagrammer",
            "get_assembly_info",
            "search_types",
            "search_strings",
            "search_methods",
            "search_fields",
            "search_properties",
            "list_events",
            "list_resources",
            "get_metadata_summary",
        )

        for name in assembly_tools:
            body = inspect.getsource(getattr(server, name))
            # Each tool must route user input through the validator.
            assert "_validate_assembly_path" in body, (
                f"Tool '{name}' does not call _validate_assembly_path"
            )