Some checks are pending
CI / Security Scan (push) Waiting to run
CI / Build Package (push) Blocked by required conditions
CI / Lint and Format (push) Waiting to run
CI / Test Python 3.11 on macos-latest (push) Waiting to run
CI / Test Python 3.12 on macos-latest (push) Waiting to run
CI / Test Python 3.13 on macos-latest (push) Waiting to run
CI / Test Python 3.10 on ubuntu-latest (push) Waiting to run
CI / Test Python 3.11 on ubuntu-latest (push) Waiting to run
CI / Test Python 3.12 on ubuntu-latest (push) Waiting to run
CI / Test Python 3.13 on ubuntu-latest (push) Waiting to run
Classifies unconnected nets into optimal wiring methods (direct wire, local/global label, power symbol, no-connect) based on pin distance, crossing count, fanout, and net name patterns. Delegates all file manipulation to apply_batch, inheriting collision detection and label serialization for free. Strategy concepts informed by KICAD-autowire (MIT, arashmparsa).
392 lines
14 KiB
Python
392 lines
14 KiB
Python
"""Tests for the autowire strategy, planner, and MCP tool."""
|
|
|
|
import json
|
|
import os
|
|
|
|
import pytest
|
|
|
|
from mckicad.autowire.planner import generate_batch_plan
|
|
from mckicad.autowire.strategy import (
|
|
NetPlan,
|
|
PinInfo,
|
|
WireSegment,
|
|
WiringMethod,
|
|
WiringThresholds,
|
|
classify_all_nets,
|
|
classify_net,
|
|
estimate_crossings,
|
|
pin_distance,
|
|
)
|
|
from tests.conftest import requires_sch_api
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Strategy unit tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPinDistance:
    """pin_distance returns the straight-line (Euclidean) distance between two pins."""

    @staticmethod
    def _pin(ref, x, y):
        """Build pin "1" of component *ref* located at (x, y)."""
        return PinInfo(reference=ref, pin_number="1", x=x, y=y)

    def test_same_point(self):
        origin = self._pin("R1", 100, 100)
        # A pin measured against itself is zero distance away.
        assert pin_distance(origin, origin) == 0.0

    def test_horizontal(self):
        left = self._pin("R1", 0, 0)
        right = self._pin("R2", 30, 0)
        assert pin_distance(left, right) == pytest.approx(30.0)

    def test_diagonal(self):
        # 3-4-5 right triangle: the hypotenuse is exactly 5.
        start = self._pin("R1", 0, 0)
        end = self._pin("R2", 3, 4)
        assert pin_distance(start, end) == pytest.approx(5.0)
|
|
|
|
|
|
class TestCrossingEstimation:
    """estimate_crossings: how many existing wires a proposed segment crosses."""

    def test_perpendicular_cross(self):
        """A horizontal wire crossing a vertical wire should count as 1."""
        proposed = WireSegment(x1=0, y1=50, x2=100, y2=50)
        existing = [WireSegment(x1=50, y1=0, x2=50, y2=100)]
        assert estimate_crossings(proposed, existing) == 1

    def test_parallel_no_cross(self):
        """Two parallel horizontal wires don't cross."""
        proposed = WireSegment(x1=0, y1=50, x2=100, y2=50)
        existing = [WireSegment(x1=0, y1=60, x2=100, y2=60)]
        assert estimate_crossings(proposed, existing) == 0

    def test_empty_existing(self):
        """With no existing wires there is nothing to cross."""
        proposed = WireSegment(x1=0, y1=0, x2=100, y2=0)
        assert estimate_crossings(proposed, []) == 0

    def test_multiple_crossings(self):
        """A horizontal wire crossing three vertical wires counts each one."""
        proposed = WireSegment(x1=0, y1=50, x2=100, y2=50)
        existing = [
            WireSegment(x1=20, y1=0, x2=20, y2=100),
            WireSegment(x1=50, y1=0, x2=50, y2=100),
            WireSegment(x1=80, y1=0, x2=80, y2=100),
        ]
        assert estimate_crossings(proposed, existing) == 3

    def test_touching_endpoint_no_cross(self):
        """A wire that merely ends on another wire (a T-junction) is not a crossing.

        The proposed wire's endpoint (100, 50) lies in the interior of the
        existing vertical wire, not at one of its endpoints. The intersection
        test uses strict inequality, so touching is excluded.
        """
        proposed = WireSegment(x1=0, y1=50, x2=100, y2=50)
        existing = [WireSegment(x1=100, y1=0, x2=100, y2=100)]
        assert estimate_crossings(proposed, existing) == 0

    def test_collinear_no_cross(self):
        """Overlapping collinear segments on the same axis are not counted as a crossing."""
        proposed = WireSegment(x1=0, y1=0, x2=100, y2=0)
        existing = [WireSegment(x1=50, y1=0, x2=150, y2=0)]
        assert estimate_crossings(proposed, existing) == 0
|
|
|
|
|
|
class TestWiringStrategy:
    """Test the classify_net decision tree."""

    def _make_pins(self, count, spacing=5.0, pin_type="passive"):
        """Build pin "1" of R1..R<count>, laid out along y=0 every *spacing* units."""
        return [
            PinInfo(
                reference=f"R{n}",
                pin_number="1",
                x=(n - 1) * spacing,
                y=0,
                pin_type=pin_type,
            )
            for n in range(1, count + 1)
        ]

    @staticmethod
    def _two_pins(x2, y2):
        """R1.1 at the origin and R2.1 at (x2, y2), both with the default pin type."""
        return [
            PinInfo(reference="R1", pin_number="1", x=0, y=0),
            PinInfo(reference="R2", pin_number="1", x=x2, y=y2),
        ]

    # -- power / no-connect / global nets --------------------------------

    def test_power_net_by_name(self):
        result = classify_net("GND", self._make_pins(2), WiringThresholds())
        assert result.method == WiringMethod.POWER_SYMBOL
        assert result.power_lib_id is not None

    def test_power_net_vcc(self):
        result = classify_net("VCC", self._make_pins(2), WiringThresholds())
        assert result.method == WiringMethod.POWER_SYMBOL

    def test_power_net_plus_3v3(self):
        result = classify_net("+3V3", self._make_pins(2), WiringThresholds())
        assert result.method == WiringMethod.POWER_SYMBOL

    def test_power_net_by_pin_type(self):
        # Intended to exercise detection through the power_in pin type
        # rather than through a well-known rail name.
        supply_pin = PinInfo(reference="U1", pin_number="1", x=0, y=0, pin_type="power_in")
        decoupling_pin = PinInfo(reference="C1", pin_number="2", x=10, y=0, pin_type="passive")
        result = classify_net("CUSTOM_RAIL", [supply_pin, decoupling_pin], WiringThresholds())
        assert result.method == WiringMethod.POWER_SYMBOL

    def test_single_pin_no_connect(self):
        result = classify_net("NC_PIN", self._make_pins(1), WiringThresholds())
        assert result.method == WiringMethod.NO_CONNECT

    def test_cross_sheet_global_label(self):
        result = classify_net(
            "SPI_CLK", self._make_pins(2), WiringThresholds(), is_cross_sheet=True
        )
        assert result.method == WiringMethod.GLOBAL_LABEL

    def test_high_fanout_global_label(self):
        six_pins = self._make_pins(6, spacing=10)
        result = classify_net("DATA_BUS", six_pins, WiringThresholds(high_fanout_threshold=5))
        assert result.method == WiringMethod.GLOBAL_LABEL
        assert "fanout" in result.reason

    # -- distance / crossing driven decisions -----------------------------

    def test_short_distance_direct_wire(self):
        result = classify_net("NET1", self._two_pins(5, 0), WiringThresholds())
        assert result.method == WiringMethod.DIRECT_WIRE
        assert "short" in result.reason

    def test_long_distance_local_label(self):
        result = classify_net(
            "NET2", self._two_pins(100, 0), WiringThresholds(direct_wire_max_distance=50)
        )
        assert result.method == WiringMethod.LOCAL_LABEL
        assert "long" in result.reason

    def test_mid_range_no_crossings_direct_wire(self):
        result = classify_net(
            "NET3", self._two_pins(30, 0), WiringThresholds(), existing_wires=[]
        )
        assert result.method == WiringMethod.DIRECT_WIRE

    def test_mid_range_many_crossings_label(self):
        # The straight route (0,0)->(0,30) cuts through three horizontal wires,
        # one more than the crossing_threshold allows.
        blockers = [WireSegment(x1=-10, y1=row, x2=10, y2=row) for row in (10, 15, 20)]
        result = classify_net(
            "NET4",
            self._two_pins(0, 30),
            WiringThresholds(crossing_threshold=2),
            existing_wires=blockers,
        )
        assert result.method == WiringMethod.LOCAL_LABEL
        assert "crossing" in result.reason

    def test_three_pin_net_local_label(self):
        result = classify_net("NET5", self._make_pins(3, spacing=10), WiringThresholds())
        assert result.method == WiringMethod.LOCAL_LABEL
        assert "multi-pin" in result.reason

    def test_custom_thresholds(self):
        """Custom thresholds override defaults."""
        pins = self._two_pins(20, 0)

        # Under the default thresholds a 20 mm run is still wired directly...
        loose = classify_net("NET", pins, WiringThresholds())
        assert loose.method == WiringMethod.DIRECT_WIRE

        # ...but capping direct wires at 15 mm pushes the same net onto a label.
        tight = classify_net("NET", pins, WiringThresholds(direct_wire_max_distance=15))
        assert tight.method == WiringMethod.LOCAL_LABEL
|
|
|
|
|
|
class TestClassifyAllNets:
    """classify_all_nets: whole-netlist classification, filtering and pin metadata."""

    def test_basic_classification(self):
        nets = {
            "GND": [["U1", "1"], ["C1", "2"]],
            "SIG1": [["U1", "2"], ["R1", "1"]],
        }
        pin_positions = {
            "U1": {"1": (100, 100), "2": (100, 110)},
            "C1": {"2": (105, 100)},
            "R1": {"1": (105, 110)},
        }
        plans = classify_all_nets(nets, pin_positions, WiringThresholds())

        methods = {p.net_name: p.method for p in plans}
        # Every pin of both nets has a position, so both nets must be planned.
        assert set(methods) == {"GND", "SIG1"}
        assert methods["GND"] == WiringMethod.POWER_SYMBOL
        assert methods["SIG1"] == WiringMethod.DIRECT_WIRE

    def test_exclude_nets(self):
        nets = {"GND": [["U1", "1"]], "SIG": [["U1", "2"], ["R1", "1"]]}
        pin_positions = {"U1": {"1": (0, 0), "2": (0, 10)}, "R1": {"1": (5, 10)}}
        plans = classify_all_nets(
            nets, pin_positions, WiringThresholds(),
            exclude_nets={"GND"},
        )
        # Compare the exact result. A bare all(...) check over the plans
        # would pass vacuously if nothing at all were returned.
        assert [p.net_name for p in plans] == ["SIG"]

    def test_only_nets(self):
        nets = {"GND": [["U1", "1"]], "SIG": [["U1", "2"], ["R1", "1"]]}
        pin_positions = {"U1": {"1": (0, 0), "2": (0, 10)}, "R1": {"1": (5, 10)}}
        plans = classify_all_nets(
            nets, pin_positions, WiringThresholds(),
            only_nets={"SIG"},
        )
        assert len(plans) == 1
        assert plans[0].net_name == "SIG"

    def test_missing_pin_positions_skipped(self):
        """Nets with pins not in pin_positions are skipped."""
        nets = {"SIG": [["U99", "1"], ["U99", "2"]]}
        pin_positions: dict = {}
        plans = classify_all_nets(nets, pin_positions, WiringThresholds())
        assert len(plans) == 0

    def test_pin_metadata_power_detection(self):
        """Pin metadata with power_in type triggers POWER_SYMBOL."""
        nets = {"CUSTOM_PWR": [["U1", "1"], ["C1", "1"]]}
        pin_positions = {"U1": {"1": (0, 0)}, "C1": {"1": (5, 0)}}
        pin_metadata = {"U1": {"1": {"pintype": "power_in", "pinfunction": "VCC"}}}
        plans = classify_all_nets(
            nets, pin_positions, WiringThresholds(),
            pin_metadata=pin_metadata,
        )
        # Check the count first, so an empty result fails with a clear
        # assertion instead of an IndexError.
        assert len(plans) == 1
        assert plans[0].net_name == "CUSTOM_PWR"
        assert plans[0].method == WiringMethod.POWER_SYMBOL
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Planner unit tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestBatchPlanGeneration:
    """generate_batch_plan: turning NetPlans into an apply_batch-style dict."""

    def _make_plan(self, method, net="NET1", pin_count=2, **kwargs):
        """Build a NetPlan over pins R1.1, R2.2, ... spaced 10 units apart on y=0."""
        pins = [
            PinInfo(reference=f"R{n}", pin_number=str(n), x=(n - 1) * 10, y=0)
            for n in range(1, pin_count + 1)
        ]
        return NetPlan(net_name=net, method=method, pins=pins, **kwargs)

    def test_direct_wire(self):
        result = generate_batch_plan([self._make_plan(WiringMethod.DIRECT_WIRE)])
        assert "wires" in result
        wires = result["wires"]
        assert len(wires) == 1
        only_wire = wires[0]
        assert only_wire["from_ref"] == "R1"
        assert only_wire["to_ref"] == "R2"

    def test_local_label(self):
        result = generate_batch_plan([self._make_plan(WiringMethod.LOCAL_LABEL, net="SIG1")])
        assert "label_connections" in result
        connection = result["label_connections"][0]
        assert connection["net"] == "SIG1"
        assert connection["global"] is False
        assert len(connection["connections"]) == 2

    def test_global_label(self):
        net_plan = self._make_plan(
            WiringMethod.GLOBAL_LABEL, net="SPI_CLK", label_shape="input",
        )
        connection = generate_batch_plan([net_plan])["label_connections"][0]
        assert connection["global"] is True
        assert connection["shape"] == "input"

    def test_power_symbol(self):
        net_plan = self._make_plan(
            WiringMethod.POWER_SYMBOL, net="GND", power_lib_id="power:GND",
        )
        result = generate_batch_plan([net_plan])
        assert "power_symbols" in result
        symbols = result["power_symbols"]
        assert len(symbols) == 2  # one per pin
        first_symbol = symbols[0]
        assert first_symbol["net"] == "GND"
        assert first_symbol["lib_id"] == "power:GND"

    def test_no_connect(self):
        result = generate_batch_plan([self._make_plan(WiringMethod.NO_CONNECT, pin_count=1)])
        assert "no_connects" in result
        assert len(result["no_connects"]) == 1

    def test_skip_produces_nothing(self):
        result = generate_batch_plan([self._make_plan(WiringMethod.SKIP)])
        assert result == {}

    def test_mixed_plan(self):
        net_plans = [
            self._make_plan(WiringMethod.DIRECT_WIRE, net="NET1"),
            self._make_plan(WiringMethod.LOCAL_LABEL, net="NET2"),
            self._make_plan(WiringMethod.POWER_SYMBOL, net="GND", power_lib_id="power:GND"),
            self._make_plan(WiringMethod.NO_CONNECT, net="NC1", pin_count=1),
        ]
        result = generate_batch_plan(net_plans)
        # Each wiring method must contribute its own section to the batch.
        for section in ("wires", "label_connections", "power_symbols", "no_connects"):
            assert section in result

    def test_existing_labels_skipped(self):
        net_plan = self._make_plan(WiringMethod.LOCAL_LABEL, net="ALREADY_PLACED")
        result = generate_batch_plan([net_plan], existing_labels={"ALREADY_PLACED"})
        assert "label_connections" not in result
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Integration tests (require kicad-sch-api)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestAutowireTool:
    """End-to-end checks of the autowire_schematic MCP tool (need kicad-sch-api)."""

    @requires_sch_api
    def test_dry_run_returns_plan(self, populated_schematic):
        from mckicad.tools.autowire import autowire_schematic

        result = autowire_schematic(populated_schematic, dry_run=True)
        assert result["success"] is True
        assert result["dry_run"] is True
        assert "strategy_summary" in result
        assert "plan" in result
        assert "batch_file" in result

        # Verify the batch file was written and holds a JSON object.
        assert os.path.isfile(result["batch_file"])
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
        with open(result["batch_file"], encoding="utf-8") as f:
            batch_data = json.load(f)
        assert isinstance(batch_data, dict)

    @requires_sch_api
    def test_dry_run_does_not_modify(self, populated_schematic):
        """dry_run=True should not change the schematic file."""
        # Compare raw bytes: text mode translates newlines and would hide a
        # rewrite that only changed line endings.
        with open(populated_schematic, "rb") as f:
            original_content = f.read()

        from mckicad.tools.autowire import autowire_schematic

        result = autowire_schematic(populated_schematic, dry_run=True)
        # A failed run leaves the file untouched trivially; require a real dry run.
        assert result["success"] is True

        with open(populated_schematic, "rb") as f:
            after_content = f.read()

        assert original_content == after_content

    @requires_sch_api
    def test_invalid_schematic_path(self):
        from mckicad.tools.autowire import autowire_schematic

        result = autowire_schematic("/nonexistent/path.kicad_sch")
        assert result["success"] is False

    @requires_sch_api
    def test_invalid_extension(self):
        from mckicad.tools.autowire import autowire_schematic

        result = autowire_schematic("/some/file.txt")
        assert result["success"] is False
|