"""Tests for skywalker-mcp server tools, validation, and safety. Calls async tool functions directly with a MockContext. No real USB hardware or MCP transport needed — tests validation, safety, and response formatting. """ import asyncio import pytest from skywalker_mcp.server import ( get_device_status as _get_device_status, get_signal_quality as _get_signal_quality, get_stream_diagnostics as _get_stream_diagnostics, sweep_spectrum as _sweep_spectrum, tune_frequency as _tune_frequency, run_blind_scan as _run_blind_scan, move_dish as _move_dish, jog_dish as _jog_dish, store_position as _store_position, set_lnb_config as _set_lnb_config, scan_i2c_bus as _scan_i2c_bus, read_i2c_register as _read_i2c_register, capture_transport_stream as _capture_transport_stream, identify_frequency as _identify_frequency, compare_surveys as _compare_surveys, mcp, MOTOR_WATCHDOG_SECS, ) # Unwrap FastMCP Tool objects → raw async functions for direct testing. # @mcp.tool() wraps each function as a Tool(fn=...) Pydantic model; # .fn gives us the original async def we can call with MockContext. 
get_device_status = _get_device_status.fn
get_signal_quality = _get_signal_quality.fn
get_stream_diagnostics = _get_stream_diagnostics.fn
sweep_spectrum = _sweep_spectrum.fn
tune_frequency = _tune_frequency.fn
run_blind_scan = _run_blind_scan.fn
move_dish = _move_dish.fn
jog_dish = _jog_dish.fn
store_position = _store_position.fn
set_lnb_config = _set_lnb_config.fn
scan_i2c_bus = _scan_i2c_bus.fn
read_i2c_register = _read_i2c_register.fn
capture_transport_stream = _capture_transport_stream.fn
identify_frequency = _identify_frequency.fn
compare_surveys = _compare_surveys.fn


# ─────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────


def _stop_watchdog(bridge):
    """Cancel the bridge's motor watchdog if it is still pending.

    Tests that start continuous drive call this in a ``finally`` so a
    watchdog task never outlives the test that created it (it would
    otherwise keep running and could halt the motor during a later test).
    """
    task = getattr(bridge, "_motor_watchdog", None)
    if task is not None and not task.done():
        bridge.cancel_motor_watchdog()


async def _wait_until(predicate, timeout=2.0, interval=0.01):
    """Poll ``predicate()`` until it is truthy or ``timeout`` seconds pass.

    Returns True if the predicate became truthy, False on timeout. Used in
    place of a single fixed sleep so slow event loops don't cause spurious
    failures, while fast ones finish as soon as the condition holds.
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while not predicate():
        if loop.time() >= deadline:
            return False
        await asyncio.sleep(interval)
    return True


# ─────────────────────────────────────────────
# Tool Registration
# ─────────────────────────────────────────────


def test_tool_count():
    """17 tools should be registered."""
    assert len(mcp._tool_manager._tools) == 17


def test_tool_names():
    """All expected tool names are present."""
    names = set(mcp._tool_manager._tools.keys())
    expected = {
        "get_device_status",
        "get_signal_quality",
        "get_stream_diagnostics",
        "sweep_spectrum",
        "tune_frequency",
        "run_blind_scan",
        "run_carrier_survey",
        "compare_surveys",
        "list_surveys",
        "move_dish",
        "jog_dish",
        "store_position",
        "set_lnb_config",
        "scan_i2c_bus",
        "read_i2c_register",
        "capture_transport_stream",
        "identify_frequency",
    }
    assert expected == names


def test_resource_count():
    """4 resources registered."""
    assert len(mcp._resource_manager._resources) == 4


def test_prompt_count():
    """2 prompts registered."""
    assert len(mcp._prompt_manager._prompts) == 2


# ─────────────────────────────────────────────
# Device Status Tools
# ─────────────────────────────────────────────


async def test_get_device_status(ctx):
    result = await get_device_status(ctx)
    assert "3.05.0-test" in result["firmware"]["version"]
    assert result["usb_speed"] == "High (480 Mbps)"
    assert "de ad be ef" in result["serial"]


async def test_get_signal_quality(ctx):
    result = await get_signal_quality(ctx)
    assert result["snr_db"] == 8.5
    assert result["locked"] is True
    assert result["agc1"] == 1200


async def test_get_stream_diagnostics(ctx):
    result = await get_stream_diagnostics(ctx)
    assert result["poll_count"] == 100
    assert result["overflow_count"] == 0


# ─────────────────────────────────────────────
# Spectrum Sweep Validation
# ─────────────────────────────────────────────


async def test_sweep_defaults(ctx):
    result = await sweep_spectrum(ctx)
    assert result["start_mhz"] == 950.0
    assert result["stop_mhz"] == 2150.0
    assert result["num_points"] > 0
    assert len(result["frequencies_mhz"]) == result["num_points"]
    assert len(result["powers_db"]) == result["num_points"]


async def test_sweep_narrow_band(ctx):
    # 1418.0 .. 1423.0 inclusive at 0.5 MHz steps -> 11 points.
    result = await sweep_spectrum(ctx, start_mhz=1418.0, stop_mhz=1423.0, step_mhz=0.5)
    assert result["num_points"] == 11
    assert result["step_mhz"] == 0.5


async def test_sweep_freq_below_range(ctx):
    result = await sweep_spectrum(ctx, start_mhz=800.0)
    assert "error" in result
    assert "950" in result["error"]


async def test_sweep_freq_above_range(ctx):
    result = await sweep_spectrum(ctx, stop_mhz=3000.0)
    assert "error" in result
    assert "2150" in result["error"]


async def test_sweep_start_gt_stop(ctx):
    result = await sweep_spectrum(ctx, start_mhz=1500.0, stop_mhz=1000.0)
    assert "error" in result
    assert "less than" in result["error"]


async def test_sweep_bad_step(ctx):
    result = await sweep_spectrum(ctx, step_mhz=0.01)
    assert "error" in result
    assert "step_mhz" in result["error"]


async def test_sweep_step_too_large(ctx):
    result = await sweep_spectrum(ctx, step_mhz=200.0)
    assert "error" in result


async def test_sweep_bad_dwell(ctx):
    result = await sweep_spectrum(ctx, dwell_ms=0)
    assert "error" in result
    assert "dwell_ms" in result["error"]


async def test_sweep_dwell_too_high(ctx):
    result = await sweep_spectrum(ctx, dwell_ms=300)
    assert "error" in result


# ─────────────────────────────────────────────
# Tune Frequency Validation
# ─────────────────────────────────────────────


async def test_tune_valid(ctx):
    result = await tune_frequency(ctx, freq_mhz=1420.0, symbol_rate_ksps=5000)
    assert result["locked"] is True
    assert result["freq_mhz"] == 1420.0
    assert result["modulation"] == "qpsk"


async def test_tune_below_range(ctx):
    result = await tune_frequency(ctx, freq_mhz=500.0)
    assert "error" in result
    assert "950" in result["error"]


async def test_tune_above_range(ctx):
    result = await tune_frequency(ctx, freq_mhz=2200.0)
    assert "error" in result


async def test_tune_bad_sr_low(ctx):
    result = await tune_frequency(ctx, freq_mhz=1200.0, symbol_rate_ksps=100)
    assert "error" in result
    assert "256" in result["error"]


async def test_tune_bad_sr_high(ctx):
    result = await tune_frequency(ctx, freq_mhz=1200.0, symbol_rate_ksps=50000)
    assert "error" in result
    assert "30000" in result["error"]


async def test_tune_bad_modulation(ctx):
    result = await tune_frequency(ctx, freq_mhz=1200.0, modulation="dvb-s2")
    assert "error" in result
    assert "dvb-s2" in result["error"]


async def test_tune_bad_dwell(ctx):
    result = await tune_frequency(ctx, freq_mhz=1200.0, dwell_ms=0)
    assert "error" in result


# ─────────────────────────────────────────────
# Blind Scan Validation
# ─────────────────────────────────────────────


async def test_blind_scan_valid(ctx):
    result = await run_blind_scan(ctx, freq_mhz=1200.0)
    assert result["locked"] is True
    assert result["sr_ksps"] == 20000.0


async def test_blind_scan_freq_below(ctx):
    result = await run_blind_scan(ctx, freq_mhz=500.0)
    assert "error" in result


async def test_blind_scan_sr_below_min(ctx):
    result = await run_blind_scan(ctx, freq_mhz=1200.0, sr_min_ksps=100)
    assert "error" in result
    assert "256" in result["error"]


async def test_blind_scan_sr_above_max(ctx):
    result = await run_blind_scan(ctx, freq_mhz=1200.0, sr_max_ksps=50000)
    assert "error" in result
    assert "30000" in result["error"]


# ─────────────────────────────────────────────
# Motor Safety Tests
# ─────────────────────────────────────────────


async def test_motor_halt(ctx, mock_device):
    result = await move_dish(ctx, action="halt")
    assert result["action"] == "halt"
    assert result["status"] == "stopped"
    assert mock_device._motor_halted is True


async def test_motor_east_stepped(ctx):
    result = await move_dish(ctx, action="east", value=10)
    assert result["steps"] == 10
    assert result["mode"] == "stepped"


async def test_motor_west_stepped(ctx):
    result = await move_dish(ctx, action="west", value=5)
    assert result["steps"] == 5
    assert result["action"] == "west"


async def test_motor_continuous_rejected(ctx):
    """Continuous drive (steps=0) without explicit flag is rejected."""
    result = await move_dish(ctx, action="east", value=0)
    assert "error" in result
    assert "CONTINUOUS" in result["error"]
    assert "continuous=True" in result["error"]


async def test_motor_continuous_with_flag(ctx, bridge):
    """Continuous drive with explicit flag succeeds and starts watchdog."""
    try:
        result = await move_dish(ctx, action="west", value=0, continuous=True)
        assert result["status"] == "driving"
        assert result["continuous"] is True
        assert result["watchdog_secs"] == MOTOR_WATCHDOG_SECS
        assert "warning" in result
    finally:
        # Continuous drive arms the watchdog; don't leave that task pending
        # after this test ends.
        _stop_watchdog(bridge)


async def test_motor_steps_negative(ctx):
    result = await move_dish(ctx, action="east", value=-5)
    assert "error" in result
    assert "0-127" in result["error"]


async def test_motor_steps_too_high(ctx):
    result = await move_dish(ctx, action="east", value=200)
    assert "error" in result
    assert "0-127" in result["error"]


async def test_motor_gotox(ctx, mock_device):
    result = await move_dish(ctx, action="gotox", value=-97.0, observer_lon=-96.8)
    assert result["action"] == "gotox"
    assert result["satellite_lon"] == -97.0
    assert "motor_angle_deg" in result
    assert ("motor_goto_x", (-96.8, -97.0), {}) in mock_device._calls


async def test_motor_goto_slot(ctx, mock_device):
    result = await move_dish(ctx, action="goto", value=5)
    assert result["slot"] == 5
    assert result["action"] == "goto"


async def test_motor_goto_slot_out_of_range(ctx):
    result = await move_dish(ctx, action="goto", value=300)
    assert "error" in result
    assert "0-255" in result["error"]


async def test_motor_invalid_action(ctx):
    result = await move_dish(ctx, action="spin")
    assert "error" in result
    assert "spin" in result["error"]


# ─────────────────────────────────────────────
# Motor Watchdog Tests
# ─────────────────────────────────────────────


async def test_watchdog_starts_on_continuous(ctx, bridge):
    """Watchdog task is created when continuous drive starts."""
    try:
        await move_dish(ctx, action="west", value=0, continuous=True)
        assert bridge._motor_watchdog is not None
        assert not bridge._motor_watchdog.done()
    finally:
        # Cancel even if an assertion above failed.
        _stop_watchdog(bridge)


async def test_watchdog_cancelled_on_halt(ctx, bridge):
    """Halt cancels the watchdog."""
    try:
        await move_dish(ctx, action="east", value=0, continuous=True)
        assert bridge._motor_watchdog is not None
        await move_dish(ctx, action="halt")
        assert bridge._motor_watchdog is None or bridge._motor_watchdog.cancelled()
    finally:
        # Ensures no leak if the halt path failed to cancel the task.
        _stop_watchdog(bridge)


async def test_watchdog_fires_and_halts(ctx, bridge, mock_device):
    """Watchdog auto-halts after timeout."""
    # Reset first so only the watchdog can set the flag; a dish halted during
    # fixture setup would otherwise make this test pass vacuously.
    mock_device._motor_halted = False
    bridge.start_motor_watchdog(timeout=0.1)  # 100ms for test speed
    try:
        # Poll against a generous deadline instead of one fixed sleep so a
        # slow event loop cannot cause a spurious failure.
        fired = await _wait_until(lambda: mock_device._motor_halted is True)
        assert fired, "motor watchdog did not halt the dish before the deadline"
    finally:
        _stop_watchdog(bridge)


# ─────────────────────────────────────────────
# Jog Dish Tests
# ─────────────────────────────────────────────


async def test_jog_valid(ctx):
    result = await jog_dish(ctx, direction="east", steps=5)
    assert result["direction"] == "east"
    assert result["steps"] == 5
    assert "snr_db" in result


async def test_jog_too_many_steps(ctx):
    result = await jog_dish(ctx, direction="east", steps=50)
    assert "error" in result
    assert "1-30" in result["error"]


async def test_jog_zero_steps(ctx):
    result = await jog_dish(ctx, direction="east", steps=0)
    assert "error" in result


async def test_jog_bad_direction(ctx):
    result = await jog_dish(ctx, direction="up")
    assert "error" in result
    assert "east" in result["error"]


# ─────────────────────────────────────────────
# Store Position Tests
# ─────────────────────────────────────────────


async def test_store_position_valid(ctx, mock_device):
    result = await store_position(ctx, slot=5)
    assert result["stored"] is True
    assert result["slot"] == 5


async def test_store_position_slot_zero(ctx):
    result = await store_position(ctx, slot=0)
    assert "error" in result


async def test_store_position_slot_too_high(ctx):
    result = await store_position(ctx, slot=300)
    assert "error" in result


# ─────────────────────────────────────────────
# LNB & I2C Tests
# ─────────────────────────────────────────────


async def test_lnb_disable(ctx, mock_device):
    result = await set_lnb_config(ctx, disable_lnb=True)
    assert result["lnb_power"] == "off"
    assert mock_device._lnb_on is False


async def test_lnb_voltage(ctx, mock_device):
    result = await set_lnb_config(ctx, voltage="18V")
    assert result["voltage"] == "18V"


async def test_lnb_tone(ctx, mock_device):
    result = await set_lnb_config(ctx, tone_22khz=True)
    assert result["tone_22khz"] is True


async def test_i2c_scan(ctx):
    result = await scan_i2c_bus(ctx)
    assert result["device_count"] == 3
    addresses = [d["address"] for d in result["devices"]]
    assert "0x08" in addresses
    assert "0x61" in addresses
    assert "0x51" in addresses


async def test_i2c_read(ctx):
    result = await read_i2c_register(ctx, slave_address=0x08, register=0x00)
    assert result["value"] == 0xAB
    assert result["hex"] == "0xAB"
    assert "0b" in result["binary"]


# ─────────────────────────────────────────────
# Transport Stream Validation
# ─────────────────────────────────────────────


async def test_ts_capture_duration_too_short(ctx):
    result = await capture_transport_stream(ctx, duration_secs=0.1)
    assert "error" in result
    assert "0.5-30" in result["error"]


async def test_ts_capture_duration_too_long(ctx):
    result = await capture_transport_stream(ctx, duration_secs=60.0)
    assert "error" in result


async def test_ts_capture_valid(ctx):
    """Valid TS capture returns packet count (mock device is locked)."""
    result = await capture_transport_stream(ctx, duration_secs=0.5)
    assert result["bytes_captured"] > 0
    assert result["packets"] > 0


# ─────────────────────────────────────────────
# Frequency Identification
# ─────────────────────────────────────────────


async def test_identify_hydrogen(ctx):
    result = await identify_frequency(ctx, freq_mhz=1420.405)
    assert result["in_if_range"] is True
    signals = [m.get("signal", "") for m in result["matches"]]
    assert any("Hydrogen" in s for s in signals)


async def test_identify_gps_l1(ctx):
    result = await identify_frequency(ctx, freq_mhz=1575.42)
    signals = [m.get("signal", "") for m in result["matches"]]
    assert any("GPS L1" in s for s in signals)


async def test_identify_gps_l5(ctx):
    result = await identify_frequency(ctx, freq_mhz=1176.45)
    signals = [m.get("signal", "") for m in result["matches"]]
    assert any("GPS L5" in s or "Galileo E5a" in s for s in signals)


async def test_identify_with_lnb(ctx):
    # RF = IF + LO: 1200 + 9750 = 10950 MHz.
    result = await identify_frequency(ctx, freq_mhz=1200.0, lnb_lo_mhz=9750.0)
    assert result["rf_freq_mhz"] == 10950.0
    assert result["lnb_lo_mhz"] == 9750.0


async def test_identify_no_lnb(ctx):
    result = await identify_frequency(ctx, freq_mhz=1200.0)
    assert result["rf_freq_mhz"] is None
    assert result["lnb_lo_mhz"] is None


# ─────────────────────────────────────────────
# Path Traversal Protection
# ─────────────────────────────────────────────


async def test_compare_path_traversal(ctx):
    result = await compare_surveys(ctx, old_filename="../../../etc/passwd", new_filename="ok.json")
    assert "error" in result
    assert "plain filename" in result["error"]


async def test_compare_dotdot_in_name(ctx):
    result = await compare_surveys(ctx, old_filename="..hidden.json", new_filename="ok.json")
    assert "error" in result


async def test_compare_slash_in_name(ctx):
    result = await compare_surveys(ctx, old_filename="subdir/file.json", new_filename="ok.json")
    assert "error" in result
    assert "plain filename" in result["error"]