"""Tests for the autowire strategy, planner, and MCP tool.""" import json import os import pytest from mckicad.autowire.planner import generate_batch_plan from mckicad.autowire.strategy import ( NetPlan, PinInfo, WireSegment, WiringMethod, WiringThresholds, classify_all_nets, classify_net, estimate_crossings, pin_distance, ) from tests.conftest import requires_sch_api # --------------------------------------------------------------------------- # Strategy unit tests # --------------------------------------------------------------------------- class TestPinDistance: def test_same_point(self): a = PinInfo(reference="R1", pin_number="1", x=100, y=100) assert pin_distance(a, a) == 0.0 def test_horizontal(self): a = PinInfo(reference="R1", pin_number="1", x=0, y=0) b = PinInfo(reference="R2", pin_number="1", x=30, y=0) assert pin_distance(a, b) == pytest.approx(30.0) def test_diagonal(self): a = PinInfo(reference="R1", pin_number="1", x=0, y=0) b = PinInfo(reference="R2", pin_number="1", x=3, y=4) assert pin_distance(a, b) == pytest.approx(5.0) class TestCrossingEstimation: def test_perpendicular_cross(self): """A horizontal wire crossing a vertical wire should count as 1.""" proposed = WireSegment(x1=0, y1=50, x2=100, y2=50) existing = [WireSegment(x1=50, y1=0, x2=50, y2=100)] assert estimate_crossings(proposed, existing) == 1 def test_parallel_no_cross(self): """Two parallel horizontal wires don't cross.""" proposed = WireSegment(x1=0, y1=50, x2=100, y2=50) existing = [WireSegment(x1=0, y1=60, x2=100, y2=60)] assert estimate_crossings(proposed, existing) == 0 def test_empty_existing(self): proposed = WireSegment(x1=0, y1=0, x2=100, y2=0) assert estimate_crossings(proposed, []) == 0 def test_multiple_crossings(self): """A horizontal wire crossing three vertical wires.""" proposed = WireSegment(x1=0, y1=50, x2=100, y2=50) existing = [ WireSegment(x1=20, y1=0, x2=20, y2=100), WireSegment(x1=50, y1=0, x2=50, y2=100), WireSegment(x1=80, y1=0, x2=80, y2=100), ] assert estimate_crossings(proposed, existing) == 3 def test_touching_endpoint_no_cross(self): """Wires that share an endpoint don't cross (strict inequality).""" proposed = WireSegment(x1=0, y1=50, x2=100, y2=50) existing = [WireSegment(x1=100, y1=0, x2=100, y2=100)] assert estimate_crossings(proposed, existing) == 0 def test_collinear_no_cross(self): """Two collinear segments (same axis) don't cross.""" proposed = WireSegment(x1=0, y1=0, x2=100, y2=0) existing = [WireSegment(x1=50, y1=0, x2=150, y2=0)] assert estimate_crossings(proposed, existing) == 0 class TestWiringStrategy: """Test the classify_net decision tree.""" def _make_pins(self, count, spacing=5.0, pin_type="passive"): return [ PinInfo( reference=f"R{i+1}", pin_number="1", x=i * spacing, y=0, pin_type=pin_type, ) for i in range(count) ] def test_power_net_by_name(self): pins = self._make_pins(2) plan = classify_net("GND", pins, WiringThresholds()) assert plan.method == WiringMethod.POWER_SYMBOL assert plan.power_lib_id is not None def test_power_net_vcc(self): pins = self._make_pins(2) plan = classify_net("VCC", pins, WiringThresholds()) assert plan.method == WiringMethod.POWER_SYMBOL def test_power_net_plus_3v3(self): pins = self._make_pins(2) plan = classify_net("+3V3", pins, WiringThresholds()) assert plan.method == WiringMethod.POWER_SYMBOL def test_power_net_by_pin_type(self): pins = [ PinInfo(reference="U1", pin_number="1", x=0, y=0, pin_type="power_in"), PinInfo(reference="C1", pin_number="2", x=10, y=0, pin_type="passive"), ] plan = classify_net("CUSTOM_RAIL", pins, WiringThresholds()) assert plan.method == WiringMethod.POWER_SYMBOL def test_single_pin_no_connect(self): pins = self._make_pins(1) plan = classify_net("NC_PIN", pins, WiringThresholds()) assert plan.method == WiringMethod.NO_CONNECT def test_cross_sheet_global_label(self): pins = self._make_pins(2) plan = classify_net("SPI_CLK", pins, WiringThresholds(), is_cross_sheet=True) assert plan.method == WiringMethod.GLOBAL_LABEL def test_high_fanout_global_label(self): pins = self._make_pins(6, spacing=10) plan = classify_net("DATA_BUS", pins, WiringThresholds(high_fanout_threshold=5)) assert plan.method == WiringMethod.GLOBAL_LABEL assert "fanout" in plan.reason def test_short_distance_direct_wire(self): pins = [ PinInfo(reference="R1", pin_number="1", x=0, y=0), PinInfo(reference="R2", pin_number="1", x=5, y=0), ] plan = classify_net("NET1", pins, WiringThresholds()) assert plan.method == WiringMethod.DIRECT_WIRE assert "short" in plan.reason def test_long_distance_local_label(self): pins = [ PinInfo(reference="R1", pin_number="1", x=0, y=0), PinInfo(reference="R2", pin_number="1", x=100, y=0), ] plan = classify_net("NET2", pins, WiringThresholds(direct_wire_max_distance=50)) assert plan.method == WiringMethod.LOCAL_LABEL assert "long" in plan.reason def test_mid_range_no_crossings_direct_wire(self): pins = [ PinInfo(reference="R1", pin_number="1", x=0, y=0), PinInfo(reference="R2", pin_number="1", x=30, y=0), ] plan = classify_net("NET3", pins, WiringThresholds(), existing_wires=[]) assert plan.method == WiringMethod.DIRECT_WIRE def test_mid_range_many_crossings_label(self): pins = [ PinInfo(reference="R1", pin_number="1", x=0, y=0), PinInfo(reference="R2", pin_number="1", x=0, y=30), ] # Vertical wire from (0,0) to (0,30), crossing 3 horizontal wires existing = [ WireSegment(x1=-10, y1=10, x2=10, y2=10), WireSegment(x1=-10, y1=15, x2=10, y2=15), WireSegment(x1=-10, y1=20, x2=10, y2=20), ] plan = classify_net( "NET4", pins, WiringThresholds(crossing_threshold=2), existing_wires=existing, ) assert plan.method == WiringMethod.LOCAL_LABEL assert "crossing" in plan.reason def test_three_pin_net_local_label(self): pins = self._make_pins(3, spacing=10) plan = classify_net("NET5", pins, WiringThresholds()) assert plan.method == WiringMethod.LOCAL_LABEL assert "multi-pin" in plan.reason def test_custom_thresholds(self): """Custom thresholds override defaults.""" pins = [ PinInfo(reference="R1", pin_number="1", x=0, y=0), PinInfo(reference="R2", pin_number="1", x=20, y=0), ] # With default threshold (50mm), 20mm is mid-range -> direct wire plan1 = classify_net("NET", pins, WiringThresholds()) assert plan1.method == WiringMethod.DIRECT_WIRE # With tight threshold (15mm), 20mm exceeds max -> label plan2 = classify_net("NET", pins, WiringThresholds(direct_wire_max_distance=15)) assert plan2.method == WiringMethod.LOCAL_LABEL class TestClassifyAllNets: def test_basic_classification(self): nets = { "GND": [["U1", "1"], ["C1", "2"]], "SIG1": [["U1", "2"], ["R1", "1"]], } pin_positions = { "U1": {"1": (100, 100), "2": (100, 110)}, "C1": {"2": (105, 100)}, "R1": {"1": (105, 110)}, } plans = classify_all_nets(nets, pin_positions, WiringThresholds()) methods = {p.net_name: p.method for p in plans} assert methods["GND"] == WiringMethod.POWER_SYMBOL assert methods["SIG1"] == WiringMethod.DIRECT_WIRE def test_exclude_nets(self): nets = {"GND": [["U1", "1"]], "SIG": [["U1", "2"], ["R1", "1"]]} pin_positions = {"U1": {"1": (0, 0), "2": (0, 10)}, "R1": {"1": (5, 10)}} plans = classify_all_nets( nets, pin_positions, WiringThresholds(), exclude_nets={"GND"}, ) assert all(p.net_name != "GND" for p in plans) def test_only_nets(self): nets = {"GND": [["U1", "1"]], "SIG": [["U1", "2"], ["R1", "1"]]} pin_positions = {"U1": {"1": (0, 0), "2": (0, 10)}, "R1": {"1": (5, 10)}} plans = classify_all_nets( nets, pin_positions, WiringThresholds(), only_nets={"SIG"}, ) assert len(plans) == 1 assert plans[0].net_name == "SIG" def test_missing_pin_positions_skipped(self): """Nets with pins not in pin_positions are skipped.""" nets = {"SIG": [["U99", "1"], ["U99", "2"]]} pin_positions: dict = {} plans = classify_all_nets(nets, pin_positions, WiringThresholds()) assert len(plans) == 0 def test_pin_metadata_power_detection(self): """Pin metadata with power_in type triggers POWER_SYMBOL.""" nets = {"CUSTOM_PWR": [["U1", "1"], ["C1", "1"]]} pin_positions = {"U1": {"1": (0, 0)}, "C1": {"1": (5, 0)}} pin_metadata = {"U1": {"1": {"pintype": "power_in", "pinfunction": "VCC"}}} plans = classify_all_nets( nets, pin_positions, WiringThresholds(), pin_metadata=pin_metadata, ) assert plans[0].method == WiringMethod.POWER_SYMBOL # --------------------------------------------------------------------------- # Planner unit tests # --------------------------------------------------------------------------- class TestBatchPlanGeneration: def _make_plan(self, method, net="NET1", pin_count=2, **kwargs): pins = [ PinInfo(reference=f"R{i+1}", pin_number=str(i+1), x=i*10, y=0) for i in range(pin_count) ] return NetPlan(net_name=net, method=method, pins=pins, **kwargs) def test_direct_wire(self): plan = self._make_plan(WiringMethod.DIRECT_WIRE) batch = generate_batch_plan([plan]) assert "wires" in batch assert len(batch["wires"]) == 1 wire = batch["wires"][0] assert wire["from_ref"] == "R1" assert wire["to_ref"] == "R2" def test_local_label(self): plan = self._make_plan(WiringMethod.LOCAL_LABEL, net="SIG1") batch = generate_batch_plan([plan]) assert "label_connections" in batch lc = batch["label_connections"][0] assert lc["net"] == "SIG1" assert lc["global"] is False assert len(lc["connections"]) == 2 def test_global_label(self): plan = self._make_plan( WiringMethod.GLOBAL_LABEL, net="SPI_CLK", label_shape="input", ) batch = generate_batch_plan([plan]) lc = batch["label_connections"][0] assert lc["global"] is True assert lc["shape"] == "input" def test_power_symbol(self): plan = self._make_plan( WiringMethod.POWER_SYMBOL, net="GND", power_lib_id="power:GND", ) batch = generate_batch_plan([plan]) assert "power_symbols" in batch assert len(batch["power_symbols"]) == 2 # one per pin assert batch["power_symbols"][0]["net"] == "GND" assert batch["power_symbols"][0]["lib_id"] == "power:GND" def test_no_connect(self): plan = self._make_plan(WiringMethod.NO_CONNECT, pin_count=1) batch = generate_batch_plan([plan]) assert "no_connects" in batch assert len(batch["no_connects"]) == 1 def test_skip_produces_nothing(self): plan = self._make_plan(WiringMethod.SKIP) batch = generate_batch_plan([plan]) assert batch == {} def test_mixed_plan(self): plans = [ self._make_plan(WiringMethod.DIRECT_WIRE, net="NET1"), self._make_plan(WiringMethod.LOCAL_LABEL, net="NET2"), self._make_plan(WiringMethod.POWER_SYMBOL, net="GND", power_lib_id="power:GND"), self._make_plan(WiringMethod.NO_CONNECT, net="NC1", pin_count=1), ] batch = generate_batch_plan(plans) assert "wires" in batch assert "label_connections" in batch assert "power_symbols" in batch assert "no_connects" in batch def test_existing_labels_skipped(self): plan = self._make_plan(WiringMethod.LOCAL_LABEL, net="ALREADY_PLACED") batch = generate_batch_plan([plan], existing_labels={"ALREADY_PLACED"}) assert "label_connections" not in batch # --------------------------------------------------------------------------- # Integration tests (require kicad-sch-api) # --------------------------------------------------------------------------- class TestAutowireTool: @requires_sch_api def test_dry_run_returns_plan(self, populated_schematic): from mckicad.tools.autowire import autowire_schematic result = autowire_schematic(populated_schematic, dry_run=True) assert result["success"] is True assert result["dry_run"] is True assert "strategy_summary" in result assert "plan" in result assert "batch_file" in result # Verify batch file was written assert os.path.isfile(result["batch_file"]) with open(result["batch_file"]) as f: batch_data = json.load(f) assert isinstance(batch_data, dict) @requires_sch_api def test_dry_run_does_not_modify(self, populated_schematic): """dry_run=True should not change the schematic file.""" with open(populated_schematic) as f: original_content = f.read() from mckicad.tools.autowire import autowire_schematic autowire_schematic(populated_schematic, dry_run=True) with open(populated_schematic) as f: after_content = f.read() assert original_content == after_content @requires_sch_api def test_invalid_schematic_path(self): from mckicad.tools.autowire import autowire_schematic result = autowire_schematic("/nonexistent/path.kicad_sch") assert result["success"] is False @requires_sch_api def test_invalid_extension(self): from mckicad.tools.autowire import autowire_schematic result = autowire_schematic("/some/file.txt") assert result["success"] is False