- Rename mcp-esptool-server -> mcesptool - Update all imports and references - Single entry point: mcesptool command - New home: git.supported.systems/MCP/mcesptool
141 lines
3.8 KiB
Python
141 lines
3.8 KiB
Python
"""
|
|
Test middleware system
|
|
"""
|
|
|
|
from unittest.mock import AsyncMock
|
|
|
|
import pytest
|
|
|
|
from mcesptool.middleware import LoggerInterceptor, MiddlewareFactory
|
|
|
|
|
|
class MockContext:
    """Stand-in for a FastMCP context, for tests only.

    Each context coroutine (``log``, ``progress``, ``request_user_input``,
    ``sample``) is replaced by its own independent ``AsyncMock``. Tests can
    therefore await them and then inspect how they were called.
    """

    def __init__(self):
        # One fresh AsyncMock per attribute, created in a fixed order.
        for attr_name in ("log", "progress", "request_user_input", "sample"):
            setattr(self, attr_name, AsyncMock())
|
|
|
|
|
|
def test_middleware_factory_supported_tools():
    """The factory's supported-tool mapping is a dict with a string entry for esptool."""
    tools = MiddlewareFactory.get_supported_tools()

    assert isinstance(tools, dict)
    assert "esptool" in tools

    esptool_entry = tools["esptool"]
    assert isinstance(esptool_entry, str)
|
|
|
|
|
|
def test_middleware_factory_tool_support_check():
    """is_tool_supported accepts esptool and rejects an unknown tool name."""
    is_supported = MiddlewareFactory.is_tool_supported

    assert is_supported("esptool")
    assert not is_supported("nonexistent_tool")
|
|
|
|
|
|
def test_middleware_factory_create_esptool():
    """create_esptool_middleware keeps the given context and an esptool_-prefixed id."""
    ctx = MockContext()
    mw = MiddlewareFactory.create_esptool_middleware(ctx)

    assert mw is not None
    assert mw.context == ctx
    assert mw.operation_id.startswith("esptool_")
|
|
|
|
|
|
def test_middleware_factory_unsupported_tool():
    """Requesting middleware for an unknown tool must raise ToolNotFoundError.

    A bare ``pytest.raises(Exception)`` would also accept unrelated failures,
    such as a TypeError from a changed call signature. That would let the test
    pass for the wrong reason. The exception class is instead matched by name
    over its MRO, so subclasses of ToolNotFoundError are accepted. This also
    avoids hard-coding the module path the class is defined in.
    """
    context = MockContext()

    with pytest.raises(Exception) as excinfo:
        MiddlewareFactory.create_middleware("unsupported_tool", context)

    raised_class_names = {cls.__name__ for cls in type(excinfo.value).__mro__}
    assert "ToolNotFoundError" in raised_class_names, (
        f"expected ToolNotFoundError, got {excinfo.type.__name__}: {excinfo.value}"
    )
|
|
|
|
|
|
def test_middleware_info():
    """get_middleware_info('esptool') returns a dict naming the tool, class and description."""
    details = MiddlewareFactory.get_middleware_info("esptool")

    assert isinstance(details, dict)
    assert details["tool_name"] == "esptool"
    for required_key in ("middleware_class", "description"):
        assert required_key in details
|
|
|
|
|
|
def test_logger_interceptor_capabilities():
    """An interceptor over a MockContext reports logging, progress and elicitation."""
    ctx = MockContext()

    # Minimal concrete subclass so the interceptor can be instantiated.
    # LoggerInterceptor presumably leaves these hook methods to subclasses.
    class _StubInterceptor(LoggerInterceptor):
        async def install_hooks(self):
            pass

        async def remove_hooks(self):
            pass

        def get_interaction_points(self):
            return ["test_operation"]

    caps = _StubInterceptor(ctx, "test_op").capabilities

    for capability in ("logging", "progress", "elicitation"):
        assert caps[capability] is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_logger_interceptor_logging():
    """The four _log_* helpers together produce four calls to context.log."""
    ctx = MockContext()

    # Minimal concrete subclass; only the logging helpers are exercised here.
    class _StubInterceptor(LoggerInterceptor):
        async def install_hooks(self):
            pass

        async def remove_hooks(self):
            pass

        def get_interaction_points(self):
            return []

    interceptor = _StubInterceptor(ctx, "test_op")

    # (helper, message) pairs, awaited in the same order as before.
    log_steps = (
        (interceptor._log_info, "Test info message"),
        (interceptor._log_warning, "Test warning"),
        (interceptor._log_error, "Test error"),
        (interceptor._log_success, "Test success"),
    )
    for log_helper, message in log_steps:
        await log_helper(message)

    assert ctx.log.call_count == len(log_steps)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_logger_interceptor_progress():
    """_update_progress calls context.progress once and records a 50% entry in history."""
    ctx = MockContext()

    # Minimal concrete subclass; only progress tracking is exercised here.
    class _StubInterceptor(LoggerInterceptor):
        async def install_hooks(self):
            pass

        async def remove_hooks(self):
            pass

        def get_interaction_points(self):
            return []

    interceptor = _StubInterceptor(ctx, "test_op")
    await interceptor._update_progress(50, "Half complete")

    ctx.progress.assert_called_once()

    history = interceptor.progress_history
    assert len(history) == 1
    first_entry = history[0]
    assert first_entry["percentage"] == 50
|