"""Security-focused tests for mcilspy validation functions.

These tests verify the security hardening in S1-S4:

- S1: Path traversal prevention via _validate_assembly_path()
- S2: Temp directory race condition fix (structural - uses TemporaryDirectory)
- S3: Subprocess output size limits (MAX_OUTPUT_BYTES)
- S4: Assembly file size limits (MAX_ASSEMBLY_SIZE_MB)
"""

import os
import tempfile
from contextlib import contextmanager
from pathlib import Path
from unittest.mock import patch

import pytest


@contextmanager
def _temp_file(suffix, content=b""):
    """Yield the path of a temp file named with *suffix*; delete it on exit.

    The file is written and closed before its path is yielded, so the code
    under test can open it independently.
    """
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as f:
        f.write(content)
        temp_path = f.name
    try:
        yield temp_path
    finally:
        os.unlink(temp_path)


class TestValidateAssemblyPath:
    """Tests for S1: _validate_assembly_path() security validation."""

    def test_empty_path_rejected(self):
        """Empty and whitespace-only paths should be rejected."""
        from mcilspy.server import AssemblyPathError, _validate_assembly_path

        with pytest.raises(AssemblyPathError, match="cannot be empty"):
            _validate_assembly_path("")
        with pytest.raises(AssemblyPathError, match="cannot be empty"):
            _validate_assembly_path(" ")

    def test_nonexistent_file_rejected(self):
        """Non-existent files should be rejected."""
        from mcilspy.server import AssemblyPathError, _validate_assembly_path

        with pytest.raises(AssemblyPathError, match="not found"):
            _validate_assembly_path("/nonexistent/path/to/assembly.dll")

    def test_directory_rejected(self):
        """Directories should be rejected (file required)."""
        from mcilspy.server import AssemblyPathError, _validate_assembly_path

        # Use a freshly created directory instead of a hard-coded "/tmp",
        # which does not exist on every platform (e.g. Windows).
        with tempfile.TemporaryDirectory() as temp_dir:
            with pytest.raises(AssemblyPathError, match="not a file"):
                _validate_assembly_path(temp_dir)

    def test_invalid_extension_rejected(self):
        """Files without .dll or .exe extension should be rejected."""
        from mcilspy.server import AssemblyPathError, _validate_assembly_path

        with _temp_file(".txt") as temp_path:
            with pytest.raises(AssemblyPathError, match="Invalid assembly extension"):
                _validate_assembly_path(temp_path)

    def test_valid_dll_accepted(self):
        """Valid .dll files should be accepted."""
        from mcilspy.server import _validate_assembly_path

        with _temp_file(".dll") as temp_path:
            result = _validate_assembly_path(temp_path)
            # Should return an absolute path to the very same file
            assert os.path.isabs(result)
            assert result.endswith(".dll")
            assert os.path.samefile(result, temp_path)

    def test_valid_exe_accepted(self):
        """Valid .exe files should be accepted."""
        from mcilspy.server import _validate_assembly_path

        with _temp_file(".exe") as temp_path:
            result = _validate_assembly_path(temp_path)
            # Should return an absolute path to the very same file
            assert os.path.isabs(result)
            assert result.endswith(".exe")
            assert os.path.samefile(result, temp_path)

    @pytest.mark.parametrize("ext", [".DLL", ".Dll", ".EXE", ".Exe"])
    def test_case_insensitive_extension(self, ext):
        """Extension check should be case-insensitive."""
        from mcilspy.server import _validate_assembly_path

        with _temp_file(ext) as temp_path:
            result = _validate_assembly_path(temp_path)
            assert os.path.isabs(result)

    def test_path_traversal_resolved(self):
        """Path traversal attempts should be resolved to absolute paths."""
        from mcilspy.server import _validate_assembly_path

        with _temp_file(".dll") as temp_path:
            parent, filename = os.path.split(temp_path)
            # <parent>/../<parent-name>/<file> names the same file via "..".
            traversal_path = os.path.join(
                parent, "..", os.path.basename(parent), filename
            )

            result = _validate_assembly_path(traversal_path)

            # Should resolve to an absolute path with no ".." components
            # (checked per component, so names merely containing ".." are fine)
            # that still points at the original file.
            assert ".." not in Path(result).parts
            assert os.path.isabs(result)
            assert os.path.samefile(result, temp_path)

    def test_tilde_expansion(self):
        """A missing file under "~" should be reported as not found."""
        from mcilspy.server import AssemblyPathError, _validate_assembly_path

        with pytest.raises(AssemblyPathError, match="not found"):
            _validate_assembly_path("~/nonexistent.dll")

    def test_tilde_expansion_uses_home_directory(self):
        """Home directory tilde should be expanded to the real file path."""
        from mcilspy.server import _validate_assembly_path

        with tempfile.TemporaryDirectory() as fake_home:
            target = Path(fake_home) / "assembly.dll"
            target.write_bytes(b"")
            # HOME is consulted on POSIX, USERPROFILE on Windows.
            with patch.dict(
                os.environ, {"HOME": fake_home, "USERPROFILE": fake_home}
            ):
                result = _validate_assembly_path("~/assembly.dll")
            assert os.path.isabs(result)
            assert os.path.samefile(result, target)


class TestMaxOutputBytes:
    """Tests for S3: MAX_OUTPUT_BYTES constant and truncation."""

    def test_constant_defined(self):
        """MAX_OUTPUT_BYTES constant should be defined."""
        from mcilspy.ilspy_wrapper import MAX_OUTPUT_BYTES

        assert MAX_OUTPUT_BYTES == 50_000_000  # 50 MB
        assert MAX_OUTPUT_BYTES > 0

    def test_constant_reasonable_size(self):
        """MAX_OUTPUT_BYTES should be a reasonable size (not too small, not too large)."""
        from mcilspy.ilspy_wrapper import MAX_OUTPUT_BYTES

        # Should be at least 1 MB
        assert MAX_OUTPUT_BYTES >= 1_000_000
        # Should be at most 100 MB to prevent memory issues
        assert MAX_OUTPUT_BYTES <= 100_000_000


class TestMaxAssemblySize:
    """Tests for S4: MAX_ASSEMBLY_SIZE_MB constant and size check."""

    def test_constant_defined(self):
        """MAX_ASSEMBLY_SIZE_MB constant should be defined."""
        from mcilspy.metadata_reader import MAX_ASSEMBLY_SIZE_MB

        assert MAX_ASSEMBLY_SIZE_MB == 500  # 500 MB
        assert MAX_ASSEMBLY_SIZE_MB > 0

    def test_assembly_size_error_defined(self):
        """AssemblySizeError exception should be defined."""
        from mcilspy.metadata_reader import AssemblySizeError

        assert issubclass(AssemblySizeError, ValueError)

    def test_oversized_file_rejected(self):
        """Files exceeding MAX_ASSEMBLY_SIZE_MB should be rejected."""
        from mcilspy.metadata_reader import (
            AssemblySizeError,
            MAX_ASSEMBLY_SIZE_MB,
            MetadataReader,
        )

        oversized_bytes = (MAX_ASSEMBLY_SIZE_MB + 1) * 1024 * 1024

        with _temp_file(".dll") as temp_path:
            # Start from the file's real stat result so every field other than
            # st_size (e.g. st_mode, read by Path.is_file()) stays valid; only
            # the size is faked.
            fields = list(os.stat(temp_path))
            fields[6] = oversized_bytes  # index 6 of stat_result is st_size
            fake_stat = os.stat_result(fields)

            with patch.object(Path, "stat", return_value=fake_stat):
                with pytest.raises(AssemblySizeError, match="exceeds maximum"):
                    MetadataReader(temp_path)

    def test_normal_sized_file_accepted(self):
        """Files under MAX_ASSEMBLY_SIZE_MB should be accepted (at init)."""
        from mcilspy.metadata_reader import MetadataReader

        with _temp_file(".dll", b"small content") as temp_path:
            # Should not raise on init (will fail later when trying to parse)
            reader = MetadataReader(temp_path)
            assert reader.assembly_path.exists()


class TestTemporaryDirectoryUsage:
    """Tests for S2: TemporaryDirectory context manager usage.

    These are structural tests that verify the code uses the secure pattern.
    """

    @staticmethod
    def _assert_uses_temp_directory_context(method):
        """Assert *method*'s source uses ``with tempfile.TemporaryDirectory(``.

        Also assert it contains no ``tempfile.mkdtemp(`` call, with or
        without arguments.
        """
        import inspect

        source = inspect.getsource(method)
        # Should use the TemporaryDirectory context manager
        assert "with tempfile.TemporaryDirectory(" in source
        # Should NOT use the old mkdtemp pattern (with or without arguments)
        assert "tempfile.mkdtemp(" not in source

    def test_decompile_uses_temp_directory_context(self):
        """Verify decompile method structure uses TemporaryDirectory."""
        from mcilspy.ilspy_wrapper import ILSpyWrapper

        self._assert_uses_temp_directory_context(ILSpyWrapper.decompile)

    def test_get_assembly_info_uses_temp_directory_context(self):
        """Verify get_assembly_info method structure uses TemporaryDirectory."""
        from mcilspy.ilspy_wrapper import ILSpyWrapper

        self._assert_uses_temp_directory_context(ILSpyWrapper.get_assembly_info)


class TestPathValidationIntegration:
    """Integration tests to verify path validation is applied to all tools."""

    # Tool functions in mcilspy.server that accept assembly_path.
    TOOLS_WITH_ASSEMBLY_PATH = (
        "decompile_assembly",
        "list_types",
        "generate_diagrammer",
        "get_assembly_info",
        "search_types",
        "search_strings",
        "search_methods",
        "search_fields",
        "search_properties",
        "list_events",
        "list_resources",
        "get_metadata_summary",
    )

    # Parametrized so each tool is reported separately instead of the first
    # failure hiding the rest.
    @pytest.mark.parametrize("tool_name", TOOLS_WITH_ASSEMBLY_PATH)
    def test_all_tools_have_path_validation(self, tool_name):
        """Verify all assembly-accepting tools call _validate_assembly_path."""
        import inspect

        from mcilspy import server

        tool_func = getattr(server, tool_name)
        source = inspect.getsource(tool_func)
        # Each tool should call _validate_assembly_path
        assert "_validate_assembly_path" in source, (
            f"Tool '{tool_name}' does not call _validate_assembly_path"
        )