kicad-mcp/tests/test_autowire.py
Ryan Malloy 1a3ffb42cd
Some checks are pending
CI / Security Scan (push) Waiting to run
CI / Build Package (push) Blocked by required conditions
CI / Lint and Format (push) Waiting to run
CI / Test Python 3.11 on macos-latest (push) Waiting to run
CI / Test Python 3.12 on macos-latest (push) Waiting to run
CI / Test Python 3.13 on macos-latest (push) Waiting to run
CI / Test Python 3.10 on ubuntu-latest (push) Waiting to run
CI / Test Python 3.11 on ubuntu-latest (push) Waiting to run
CI / Test Python 3.12 on ubuntu-latest (push) Waiting to run
CI / Test Python 3.13 on ubuntu-latest (push) Waiting to run
Add autowire_schematic tool for automated net wiring strategy
Classifies unconnected nets into optimal wiring methods (direct wire,
local/global label, power symbol, no-connect) based on pin distance,
crossing count, fanout, and net name patterns. Delegates all file
manipulation to apply_batch, inheriting collision detection and label
serialization for free.

Strategy concepts informed by KICAD-autowire (MIT, arashmparsa).
2026-03-08 18:27:06 -06:00

392 lines
14 KiB
Python

"""Tests for the autowire strategy, planner, and MCP tool."""
import json
import os
import pytest
from mckicad.autowire.planner import generate_batch_plan
from mckicad.autowire.strategy import (
NetPlan,
PinInfo,
WireSegment,
WiringMethod,
WiringThresholds,
classify_all_nets,
classify_net,
estimate_crossings,
pin_distance,
)
from tests.conftest import requires_sch_api
# ---------------------------------------------------------------------------
# Strategy unit tests
# ---------------------------------------------------------------------------
class TestPinDistance:
    """Straight-line distance between two pin positions."""

    def test_same_point(self):
        """A pin is zero distance from itself."""
        pin = PinInfo(reference="R1", pin_number="1", x=100, y=100)
        assert pin_distance(pin, pin) == 0.0

    def test_horizontal(self):
        """Pins on the same row are separated by their x offset."""
        left = PinInfo(reference="R1", pin_number="1", x=0, y=0)
        right = PinInfo(reference="R2", pin_number="1", x=30, y=0)
        assert pin_distance(left, right) == pytest.approx(30.0)

    def test_diagonal(self):
        """A 3-4-5 right triangle yields a hypotenuse of 5."""
        origin = PinInfo(reference="R1", pin_number="1", x=0, y=0)
        corner = PinInfo(reference="R2", pin_number="1", x=3, y=4)
        expected = 5.0
        assert pin_distance(origin, corner) == pytest.approx(expected)
class TestCrossingEstimation:
    """Counting how many existing wires a proposed segment would cross."""

    @staticmethod
    def _horizontal_at_50():
        """Proposed wire running from x=0 to x=100 along y=50."""
        return WireSegment(x1=0, y1=50, x2=100, y2=50)

    def test_perpendicular_cross(self):
        """A horizontal wire crossing a vertical wire should count as 1."""
        vertical = WireSegment(x1=50, y1=0, x2=50, y2=100)
        assert estimate_crossings(self._horizontal_at_50(), [vertical]) == 1

    def test_parallel_no_cross(self):
        """Two parallel horizontal wires don't cross."""
        parallel = WireSegment(x1=0, y1=60, x2=100, y2=60)
        assert estimate_crossings(self._horizontal_at_50(), [parallel]) == 0

    def test_empty_existing(self):
        """With nothing on the sheet there is nothing to cross."""
        candidate = WireSegment(x1=0, y1=0, x2=100, y2=0)
        assert estimate_crossings(candidate, []) == 0

    def test_multiple_crossings(self):
        """A horizontal wire crossing three vertical wires."""
        verticals = [
            WireSegment(x1=column, y1=0, x2=column, y2=100)
            for column in (20, 50, 80)
        ]
        assert estimate_crossings(self._horizontal_at_50(), verticals) == 3

    def test_touching_endpoint_no_cross(self):
        """Wires that share an endpoint don't cross (strict inequality)."""
        at_end = WireSegment(x1=100, y1=0, x2=100, y2=100)
        assert estimate_crossings(self._horizontal_at_50(), [at_end]) == 0

    def test_collinear_no_cross(self):
        """Two collinear segments (same axis) don't cross."""
        candidate = WireSegment(x1=0, y1=0, x2=100, y2=0)
        overlapping = WireSegment(x1=50, y1=0, x2=150, y2=0)
        assert estimate_crossings(candidate, [overlapping]) == 0
class TestWiringStrategy:
    """Exercise each branch of the classify_net decision tree."""

    def _make_pins(self, count, spacing=5.0, pin_type="passive"):
        """Return ``count`` pins R1..Rn (pin "1") on y=0, ``spacing`` apart in x."""
        return [
            PinInfo(
                reference=f"R{n}",
                pin_number="1",
                x=(n - 1) * spacing,
                y=0,
                pin_type=pin_type,
            )
            for n in range(1, count + 1)
        ]

    @staticmethod
    def _pin_pair(dx, dy):
        """Return R1.1 at the origin and R2.1 at (dx, dy), default pin type."""
        return [
            PinInfo(reference="R1", pin_number="1", x=0, y=0),
            PinInfo(reference="R2", pin_number="1", x=dx, y=dy),
        ]

    # -- power nets -----------------------------------------------------------

    def test_power_net_by_name(self):
        plan = classify_net("GND", self._make_pins(2), WiringThresholds())
        assert plan.method == WiringMethod.POWER_SYMBOL
        assert plan.power_lib_id is not None

    def test_power_net_vcc(self):
        plan = classify_net("VCC", self._make_pins(2), WiringThresholds())
        assert plan.method == WiringMethod.POWER_SYMBOL

    def test_power_net_plus_3v3(self):
        plan = classify_net("+3V3", self._make_pins(2), WiringThresholds())
        assert plan.method == WiringMethod.POWER_SYMBOL

    def test_power_net_by_pin_type(self):
        # The net name is not a recognised rail; the power_in pin type decides.
        supply = PinInfo(reference="U1", pin_number="1", x=0, y=0, pin_type="power_in")
        decoupler = PinInfo(reference="C1", pin_number="2", x=10, y=0, pin_type="passive")
        plan = classify_net("CUSTOM_RAIL", [supply, decoupler], WiringThresholds())
        assert plan.method == WiringMethod.POWER_SYMBOL

    # -- labels / no-connect --------------------------------------------------

    def test_single_pin_no_connect(self):
        plan = classify_net("NC_PIN", self._make_pins(1), WiringThresholds())
        assert plan.method == WiringMethod.NO_CONNECT

    def test_cross_sheet_global_label(self):
        plan = classify_net(
            "SPI_CLK", self._make_pins(2), WiringThresholds(), is_cross_sheet=True
        )
        assert plan.method == WiringMethod.GLOBAL_LABEL

    def test_high_fanout_global_label(self):
        # Six pins against a fanout threshold of five.
        fanout_pins = self._make_pins(6, spacing=10)
        limits = WiringThresholds(high_fanout_threshold=5)
        plan = classify_net("DATA_BUS", fanout_pins, limits)
        assert plan.method == WiringMethod.GLOBAL_LABEL
        assert "fanout" in plan.reason

    # -- distance / crossing driven choices -----------------------------------

    def test_short_distance_direct_wire(self):
        plan = classify_net("NET1", self._pin_pair(5, 0), WiringThresholds())
        assert plan.method == WiringMethod.DIRECT_WIRE
        assert "short" in plan.reason

    def test_long_distance_local_label(self):
        limits = WiringThresholds(direct_wire_max_distance=50)
        plan = classify_net("NET2", self._pin_pair(100, 0), limits)
        assert plan.method == WiringMethod.LOCAL_LABEL
        assert "long" in plan.reason

    def test_mid_range_no_crossings_direct_wire(self):
        plan = classify_net(
            "NET3", self._pin_pair(30, 0), WiringThresholds(), existing_wires=[]
        )
        assert plan.method == WiringMethod.DIRECT_WIRE

    def test_mid_range_many_crossings_label(self):
        # A vertical run from (0,0) to (0,30) cuts three horizontal wires.
        obstacles = [
            WireSegment(x1=-10, y1=row, x2=10, y2=row)
            for row in (10, 15, 20)
        ]
        plan = classify_net(
            "NET4",
            self._pin_pair(0, 30),
            WiringThresholds(crossing_threshold=2),
            existing_wires=obstacles,
        )
        assert plan.method == WiringMethod.LOCAL_LABEL
        assert "crossing" in plan.reason

    def test_three_pin_net_local_label(self):
        plan = classify_net("NET5", self._make_pins(3, spacing=10), WiringThresholds())
        assert plan.method == WiringMethod.LOCAL_LABEL
        assert "multi-pin" in plan.reason

    def test_custom_thresholds(self):
        """Custom thresholds override defaults."""
        pins = self._pin_pair(20, 0)
        # Under the default thresholds a 20mm pair is wired directly ...
        loose = classify_net("NET", pins, WiringThresholds())
        assert loose.method == WiringMethod.DIRECT_WIRE
        # ... but capping direct wires at 15mm pushes the same net to a label.
        tight = classify_net("NET", pins, WiringThresholds(direct_wire_max_distance=15))
        assert tight.method == WiringMethod.LOCAL_LABEL
class TestClassifyAllNets:
    """classify_all_nets: netlist + pin positions -> one plan per wirable net."""

    def test_basic_classification(self):
        """A named power rail and a short two-pin signal net."""
        nets = {
            "GND": [["U1", "1"], ["C1", "2"]],
            "SIG1": [["U1", "2"], ["R1", "1"]],
        }
        pin_positions = {
            "U1": {"1": (100, 100), "2": (100, 110)},
            "C1": {"2": (105, 100)},
            "R1": {"1": (105, 110)},
        }
        plans = classify_all_nets(nets, pin_positions, WiringThresholds())
        methods = {p.net_name: p.method for p in plans}
        assert methods["GND"] == WiringMethod.POWER_SYMBOL
        assert methods["SIG1"] == WiringMethod.DIRECT_WIRE

    def test_exclude_nets(self):
        """Excluded nets are dropped while every other net is still classified."""
        nets = {"GND": [["U1", "1"]], "SIG": [["U1", "2"], ["R1", "1"]]}
        pin_positions = {"U1": {"1": (0, 0), "2": (0, 10)}, "R1": {"1": (5, 10)}}
        plans = classify_all_nets(
            nets, pin_positions, WiringThresholds(),
            exclude_nets={"GND"},
        )
        names = [p.net_name for p in plans]
        assert "GND" not in names
        # Guard against a vacuous pass: an empty result would satisfy the
        # check above, so also require that the non-excluded net survives.
        assert names == ["SIG"]

    def test_only_nets(self):
        """only_nets restricts classification to the named nets."""
        nets = {"GND": [["U1", "1"]], "SIG": [["U1", "2"], ["R1", "1"]]}
        pin_positions = {"U1": {"1": (0, 0), "2": (0, 10)}, "R1": {"1": (5, 10)}}
        plans = classify_all_nets(
            nets, pin_positions, WiringThresholds(),
            only_nets={"SIG"},
        )
        assert len(plans) == 1
        assert plans[0].net_name == "SIG"

    def test_missing_pin_positions_skipped(self):
        """Nets with pins not in pin_positions are skipped."""
        nets = {"SIG": [["U99", "1"], ["U99", "2"]]}
        pin_positions: dict = {}
        plans = classify_all_nets(nets, pin_positions, WiringThresholds())
        assert len(plans) == 0

    def test_pin_metadata_power_detection(self):
        """Pin metadata with power_in type triggers POWER_SYMBOL."""
        nets = {"CUSTOM_PWR": [["U1", "1"], ["C1", "1"]]}
        pin_positions = {"U1": {"1": (0, 0)}, "C1": {"1": (5, 0)}}
        pin_metadata = {"U1": {"1": {"pintype": "power_in", "pinfunction": "VCC"}}}
        plans = classify_all_nets(
            nets, pin_positions, WiringThresholds(),
            pin_metadata=pin_metadata,
        )
        # Check the count first so a dropped net fails clearly, not with IndexError.
        assert len(plans) == 1
        assert plans[0].method == WiringMethod.POWER_SYMBOL
# ---------------------------------------------------------------------------
# Planner unit tests
# ---------------------------------------------------------------------------
class TestBatchPlanGeneration:
    """generate_batch_plan: turning per-net plans into batch-dict sections."""

    def _make_plan(self, method, net="NET1", pin_count=2, **kwargs):
        """Build a NetPlan for ``net`` with pins R1.1, R2.2, ... spaced 10 apart on y=0."""
        pins = [
            PinInfo(reference=f"R{n}", pin_number=str(n), x=(n - 1) * 10, y=0)
            for n in range(1, pin_count + 1)
        ]
        return NetPlan(net_name=net, method=method, pins=pins, **kwargs)

    def test_direct_wire(self):
        batch = generate_batch_plan([self._make_plan(WiringMethod.DIRECT_WIRE)])
        assert "wires" in batch
        wires = batch["wires"]
        assert len(wires) == 1
        only_wire = wires[0]
        assert only_wire["from_ref"] == "R1"
        assert only_wire["to_ref"] == "R2"

    def test_local_label(self):
        local = self._make_plan(WiringMethod.LOCAL_LABEL, net="SIG1")
        batch = generate_batch_plan([local])
        assert "label_connections" in batch
        entry = batch["label_connections"][0]
        assert entry["net"] == "SIG1"
        assert entry["global"] is False
        assert len(entry["connections"]) == 2

    def test_global_label(self):
        glob = self._make_plan(
            WiringMethod.GLOBAL_LABEL,
            net="SPI_CLK",
            label_shape="input",
        )
        entry = generate_batch_plan([glob])["label_connections"][0]
        assert entry["global"] is True
        assert entry["shape"] == "input"

    def test_power_symbol(self):
        rail = self._make_plan(
            WiringMethod.POWER_SYMBOL,
            net="GND",
            power_lib_id="power:GND",
        )
        batch = generate_batch_plan([rail])
        assert "power_symbols" in batch
        symbols = batch["power_symbols"]
        assert len(symbols) == 2  # one per pin
        assert symbols[0]["net"] == "GND"
        assert symbols[0]["lib_id"] == "power:GND"

    def test_no_connect(self):
        batch = generate_batch_plan([self._make_plan(WiringMethod.NO_CONNECT, pin_count=1)])
        assert "no_connects" in batch
        assert len(batch["no_connects"]) == 1

    def test_skip_produces_nothing(self):
        batch = generate_batch_plan([self._make_plan(WiringMethod.SKIP)])
        assert batch == {}

    def test_mixed_plan(self):
        """One plan of each kind populates every batch section."""
        batch = generate_batch_plan([
            self._make_plan(WiringMethod.DIRECT_WIRE, net="NET1"),
            self._make_plan(WiringMethod.LOCAL_LABEL, net="NET2"),
            self._make_plan(WiringMethod.POWER_SYMBOL, net="GND", power_lib_id="power:GND"),
            self._make_plan(WiringMethod.NO_CONNECT, net="NC1", pin_count=1),
        ])
        for section in ("wires", "label_connections", "power_symbols", "no_connects"):
            assert section in batch

    def test_existing_labels_skipped(self):
        """Nets whose label is already on the sheet are not re-labelled."""
        placed = self._make_plan(WiringMethod.LOCAL_LABEL, net="ALREADY_PLACED")
        batch = generate_batch_plan([placed], existing_labels={"ALREADY_PLACED"})
        assert "label_connections" not in batch
# ---------------------------------------------------------------------------
# Integration tests (require kicad-sch-api)
# ---------------------------------------------------------------------------
class TestAutowireTool:
    """End-to-end checks of the ``autowire_schematic`` tool (needs kicad-sch-api)."""

    @requires_sch_api
    def test_dry_run_returns_plan(self, populated_schematic):
        """A dry run reports its strategy and writes the batch file it would apply."""
        from mckicad.tools.autowire import autowire_schematic
        result = autowire_schematic(populated_schematic, dry_run=True)
        assert result["success"] is True
        assert result["dry_run"] is True
        assert "strategy_summary" in result
        assert "plan" in result
        assert "batch_file" in result
        # Verify batch file was written and holds a JSON object. Reading bytes
        # lets json detect the encoding instead of relying on the locale default.
        assert os.path.isfile(result["batch_file"])
        with open(result["batch_file"], "rb") as f:
            batch_data = json.load(f)
        assert isinstance(batch_data, dict)

    @requires_sch_api
    def test_dry_run_does_not_modify(self, populated_schematic):
        """dry_run=True should not change the schematic file."""
        # Compare raw bytes. Text mode would translate newlines, which could
        # hide a CRLF->LF rewrite, and would depend on the platform encoding.
        with open(populated_schematic, "rb") as f:
            original_content = f.read()
        from mckicad.tools.autowire import autowire_schematic
        result = autowire_schematic(populated_schematic, dry_run=True)
        # A run that bailed out early would trivially leave the file alone,
        # so require that the dry run actually succeeded.
        assert result["success"] is True
        with open(populated_schematic, "rb") as f:
            after_content = f.read()
        assert original_content == after_content

    @requires_sch_api
    def test_invalid_schematic_path(self):
        """A nonexistent schematic path is reported as a failure."""
        from mckicad.tools.autowire import autowire_schematic
        result = autowire_schematic("/nonexistent/path.kicad_sch")
        assert result["success"] is False

    @requires_sch_api
    def test_invalid_extension(self, tmp_path):
        """A real file with a non-.kicad_sch extension is rejected."""
        from mckicad.tools.autowire import autowire_schematic
        # Use an existing file so the failure cannot come from a
        # missing-file check instead of the extension check.
        wrong_type = tmp_path / "file.txt"
        wrong_type.write_text("")
        result = autowire_schematic(str(wrong_type))
        assert result["success"] is False