""" Test middleware system """ from unittest.mock import AsyncMock import pytest from mcp_esptool_server.middleware import LoggerInterceptor, MiddlewareFactory class MockContext: """Mock FastMCP context for testing""" def __init__(self): self.log = AsyncMock() self.progress = AsyncMock() self.request_user_input = AsyncMock() self.sample = AsyncMock() def test_middleware_factory_supported_tools(): """Test middleware factory tool support""" supported = MiddlewareFactory.get_supported_tools() assert isinstance(supported, dict) assert "esptool" in supported assert isinstance(supported["esptool"], str) def test_middleware_factory_tool_support_check(): """Test tool support checking""" assert MiddlewareFactory.is_tool_supported("esptool") assert not MiddlewareFactory.is_tool_supported("nonexistent_tool") def test_middleware_factory_create_esptool(): """Test ESPTool middleware creation""" context = MockContext() middleware = MiddlewareFactory.create_esptool_middleware(context) assert middleware is not None assert middleware.context == context assert middleware.operation_id.startswith("esptool_") def test_middleware_factory_unsupported_tool(): """Test error handling for unsupported tools""" context = MockContext() with pytest.raises(Exception): # Should raise ToolNotFoundError MiddlewareFactory.create_middleware("unsupported_tool", context) def test_middleware_info(): """Test middleware information retrieval""" info = MiddlewareFactory.get_middleware_info("esptool") assert isinstance(info, dict) assert info["tool_name"] == "esptool" assert "middleware_class" in info assert "description" in info def test_logger_interceptor_capabilities(): """Test logger interceptor capability detection""" context = MockContext() # Create a concrete implementation for testing class TestInterceptor(LoggerInterceptor): async def install_hooks(self): pass async def remove_hooks(self): pass def get_interaction_points(self): return ["test_operation"] interceptor = TestInterceptor(context, "test_op") assert interceptor.capabilities["logging"] is True assert interceptor.capabilities["progress"] is True assert interceptor.capabilities["elicitation"] is True @pytest.mark.asyncio async def test_logger_interceptor_logging(): """Test logger interceptor logging methods""" context = MockContext() class TestInterceptor(LoggerInterceptor): async def install_hooks(self): pass async def remove_hooks(self): pass def get_interaction_points(self): return [] interceptor = TestInterceptor(context, "test_op") # Test logging methods await interceptor._log_info("Test info message") await interceptor._log_warning("Test warning") await interceptor._log_error("Test error") await interceptor._log_success("Test success") # Verify context.log was called assert context.log.call_count == 4 @pytest.mark.asyncio async def test_logger_interceptor_progress(): """Test logger interceptor progress tracking""" context = MockContext() class TestInterceptor(LoggerInterceptor): async def install_hooks(self): pass async def remove_hooks(self): pass def get_interaction_points(self): return [] interceptor = TestInterceptor(context, "test_op") # Test progress update await interceptor._update_progress(50, "Half complete") # Verify context.progress was called context.progress.assert_called_once() # Check progress history assert len(interceptor.progress_history) == 1 assert interceptor.progress_history[0]["percentage"] == 50