Add autowire_schematic tool for automated net wiring strategy
Some checks are pending
CI / Security Scan (push) Waiting to run
CI / Build Package (push) Blocked by required conditions
CI / Lint and Format (push) Waiting to run
CI / Test Python 3.11 on macos-latest (push) Waiting to run
CI / Test Python 3.12 on macos-latest (push) Waiting to run
CI / Test Python 3.13 on macos-latest (push) Waiting to run
CI / Test Python 3.10 on ubuntu-latest (push) Waiting to run
CI / Test Python 3.11 on ubuntu-latest (push) Waiting to run
CI / Test Python 3.12 on ubuntu-latest (push) Waiting to run
CI / Test Python 3.13 on ubuntu-latest (push) Waiting to run

Classifies unconnected nets into optimal wiring methods (direct wire,
local/global label, power symbol, no-connect) based on pin distance,
crossing count, fanout, and net name patterns. Delegates all file
manipulation to apply_batch, inheriting collision detection and label
serialization for free.

Strategy concepts informed by KICAD-autowire (MIT, arashmparsa).
This commit is contained in:
Ryan Malloy 2026-03-08 18:27:06 -06:00
parent eea91036f8
commit 1a3ffb42cd
7 changed files with 1204 additions and 0 deletions

View File

View File

@ -0,0 +1,105 @@
"""Convert NetPlan decisions into batch JSON for apply_batch.
Maps each wiring strategy to the batch schema that
``_apply_batch_operations()`` already consumes, so all file manipulation,
collision detection, and label serialization workarounds are inherited.
"""
from __future__ import annotations
from mckicad.autowire.strategy import NetPlan, WiringMethod
def generate_batch_plan(
plans: list[NetPlan],
existing_labels: set[str] | None = None,
) -> dict:
"""Convert a list of NetPlan decisions into a batch JSON dict.
Args:
plans: Classified net plans from ``classify_all_nets()``.
existing_labels: Net names that already have labels placed
(to avoid duplicates).
Returns:
Dict matching the ``apply_batch`` JSON schema with keys:
``wires``, ``label_connections``, ``power_symbols``, ``no_connects``.
"""
if existing_labels is None:
existing_labels = set()
wires: list[dict] = []
label_connections: list[dict] = []
power_symbols: list[dict] = []
no_connects: list[dict] = []
for plan in plans:
if plan.method == WiringMethod.SKIP:
continue
if plan.method == WiringMethod.DIRECT_WIRE:
# Two-pin net: wire between the pins
if len(plan.pins) == 2:
wires.append({
"from_ref": plan.pins[0].reference,
"from_pin": plan.pins[0].pin_number,
"to_ref": plan.pins[1].reference,
"to_pin": plan.pins[1].pin_number,
})
elif plan.method == WiringMethod.LOCAL_LABEL:
if plan.net_name in existing_labels:
continue
connections = [
{"ref": pin.reference, "pin": pin.pin_number}
for pin in plan.pins
]
label_connections.append({
"net": plan.net_name,
"global": False,
"connections": connections,
})
elif plan.method == WiringMethod.GLOBAL_LABEL:
if plan.net_name in existing_labels:
continue
connections = [
{"ref": pin.reference, "pin": pin.pin_number}
for pin in plan.pins
]
label_connections.append({
"net": plan.net_name,
"global": True,
"shape": plan.label_shape,
"connections": connections,
})
elif plan.method == WiringMethod.POWER_SYMBOL:
for pin in plan.pins:
entry: dict = {
"net": plan.net_name,
"pin_ref": pin.reference,
"pin_number": pin.pin_number,
}
if plan.power_lib_id:
entry["lib_id"] = plan.power_lib_id
power_symbols.append(entry)
elif plan.method == WiringMethod.NO_CONNECT:
for pin in plan.pins:
no_connects.append({
"pin_ref": pin.reference,
"pin_number": pin.pin_number,
})
result: dict = {}
if wires:
result["wires"] = wires
if label_connections:
result["label_connections"] = label_connections
if power_symbols:
result["power_symbols"] = power_symbols
if no_connects:
result["no_connects"] = no_connects
return result

View File

@ -0,0 +1,364 @@
"""Pure wiring strategy logic for automated schematic connectivity.
Decides how to connect each net (direct wire, local label, global label,
power symbol, no-connect, or skip) based on distance, fanout, crossing
count, and net name patterns.
Strategy concepts informed by KICAD-autowire (MIT, arashmparsa).
Implementation is original, built on mckicad's batch infrastructure.
No I/O or kicad-sch-api dependency all functions operate on plain
dataclasses and return deterministic results for a given input.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import StrEnum
import math
import re
class WiringMethod(StrEnum):
"""How a net should be physically connected in the schematic."""
DIRECT_WIRE = "direct_wire"
LOCAL_LABEL = "local_label"
GLOBAL_LABEL = "global_label"
POWER_SYMBOL = "power_symbol"
NO_CONNECT = "no_connect"
SKIP = "skip"
@dataclass
class PinInfo:
"""A component pin with its schematic-space coordinates and metadata."""
reference: str
pin_number: str
x: float
y: float
pin_type: str = ""
pin_name: str = ""
@dataclass
class NetPlan:
"""The chosen wiring strategy for a single net."""
net_name: str
method: WiringMethod
pins: list[PinInfo]
reason: str = ""
label_shape: str = "bidirectional"
power_lib_id: str | None = None
@dataclass
class WiringThresholds:
"""Tunable parameters for the wiring decision tree."""
direct_wire_max_distance: float = 50.0
crossing_threshold: int = 2
high_fanout_threshold: int = 5
label_min_distance: float = 10.0
power_net_patterns: list[str] = field(default_factory=lambda: [
"GND", "VCC", "VDD", "VSS",
"+3V3", "+5V", "+3.3V", "+12V", "+1.8V",
"VBUS", "VBAT",
])
# Pre-compiled pattern built from default list; rebuilt per-call when
# custom patterns are provided.
_DEFAULT_POWER_RE = re.compile(
r"^(GND[A-Z0-9_]*|VSS[A-Z0-9_]*|VCC[A-Z0-9_]*|VDD[A-Z0-9_]*"
r"|AGND|DGND|PGND|SGND"
r"|\+\d+V\d*|\+\d+\.\d+V"
r"|VBUS|VBAT)$",
re.IGNORECASE,
)
def _build_power_re(patterns: list[str]) -> re.Pattern[str]:
"""Build a regex from a list of power net name patterns."""
escaped = [re.escape(p) for p in patterns]
return re.compile(r"^(" + "|".join(escaped) + r")$", re.IGNORECASE)
def _is_power_net(
net_name: str,
pins: list[PinInfo],
thresholds: WiringThresholds,
) -> bool:
"""Check if a net is a power rail by name or pin type."""
name = net_name.strip()
# Check against default broad pattern first
if _DEFAULT_POWER_RE.match(name):
return True
# Check against user-provided patterns
if thresholds.power_net_patterns:
custom_re = _build_power_re(thresholds.power_net_patterns)
if custom_re.match(name):
return True
# Check pin types — if any pin is power_in or power_out, treat as power
for pin in pins:
pt = pin.pin_type.lower()
if pt in ("power_in", "power_out"):
return True
return False
def pin_distance(a: PinInfo, b: PinInfo) -> float:
"""Euclidean distance between two pins in mm."""
return math.sqrt((a.x - b.x) ** 2 + (a.y - b.y) ** 2)
# -- Crossing estimation ---------------------------------------------------
@dataclass
class WireSegment:
"""An axis-aligned wire segment for crossing estimation."""
x1: float
y1: float
x2: float
y2: float
def _segments_cross(a: WireSegment, b: WireSegment) -> bool:
"""Check if two axis-aligned segments cross (perpendicular intersection).
A horizontal segment crosses a vertical segment when the vertical's
X falls within the horizontal's X range, and the horizontal's Y falls
within the vertical's Y range.
"""
a_horiz = abs(a.y1 - a.y2) < 0.01
a_vert = abs(a.x1 - a.x2) < 0.01
b_horiz = abs(b.y1 - b.y2) < 0.01
b_vert = abs(b.x1 - b.x2) < 0.01
if a_horiz and b_vert:
h, v = a, b
elif a_vert and b_horiz:
h, v = b, a
else:
return False
h_x_min = min(h.x1, h.x2)
h_x_max = max(h.x1, h.x2)
h_y = h.y1
v_y_min = min(v.y1, v.y2)
v_y_max = max(v.y1, v.y2)
v_x = v.x1
return (h_x_min < v_x < h_x_max) and (v_y_min < h_y < v_y_max)
def estimate_crossings(
proposed: WireSegment,
existing: list[WireSegment],
) -> int:
"""Count how many existing wire segments the proposed wire would cross."""
return sum(1 for seg in existing if _segments_cross(proposed, seg))
def _pin_pair_wire(a: PinInfo, b: PinInfo) -> WireSegment:
"""Create a hypothetical wire segment between two pins."""
return WireSegment(x1=a.x, y1=a.y, x2=b.x, y2=b.y)
# -- Classification --------------------------------------------------------
def classify_net(
net_name: str,
pins: list[PinInfo],
thresholds: WiringThresholds,
existing_wires: list[WireSegment] | None = None,
is_cross_sheet: bool = False,
) -> NetPlan:
"""Choose a wiring method for a single net.
Decision tree:
1. Power net? -> POWER_SYMBOL
2. Single-pin net? -> NO_CONNECT
3. Cross-sheet net? -> GLOBAL_LABEL
4. High fanout (>threshold pins)? -> GLOBAL_LABEL
5. Two-pin net:
- distance <= label_min_distance -> DIRECT_WIRE
- distance > direct_wire_max_distance -> LOCAL_LABEL
- else: estimate crossings -> >threshold? LOCAL_LABEL : DIRECT_WIRE
6. 3-4 pin net (below fanout threshold) -> LOCAL_LABEL
"""
if existing_wires is None:
existing_wires = []
# 1. Power net detection
if _is_power_net(net_name, pins, thresholds):
from mckicad.patterns._geometry import resolve_power_lib_id
return NetPlan(
net_name=net_name,
method=WiringMethod.POWER_SYMBOL,
pins=pins,
power_lib_id=resolve_power_lib_id(net_name),
reason="power net (name or pin type match)",
)
# 2. Single-pin net -> no-connect
if len(pins) == 1:
return NetPlan(
net_name=net_name,
method=WiringMethod.NO_CONNECT,
pins=pins,
reason="single-pin net",
)
# 3. Cross-sheet net -> global label
if is_cross_sheet:
return NetPlan(
net_name=net_name,
method=WiringMethod.GLOBAL_LABEL,
pins=pins,
label_shape="bidirectional",
reason="cross-sheet net",
)
# 4. High fanout -> global label
if len(pins) >= thresholds.high_fanout_threshold:
return NetPlan(
net_name=net_name,
method=WiringMethod.GLOBAL_LABEL,
pins=pins,
label_shape="bidirectional",
reason=f"high fanout ({len(pins)} pins >= {thresholds.high_fanout_threshold})",
)
# 5. Two-pin net: distance + crossing heuristic
if len(pins) == 2:
dist = pin_distance(pins[0], pins[1])
if dist <= thresholds.label_min_distance:
return NetPlan(
net_name=net_name,
method=WiringMethod.DIRECT_WIRE,
pins=pins,
reason=f"short distance ({dist:.1f}mm <= {thresholds.label_min_distance}mm)",
)
if dist > thresholds.direct_wire_max_distance:
return NetPlan(
net_name=net_name,
method=WiringMethod.LOCAL_LABEL,
pins=pins,
reason=f"long distance ({dist:.1f}mm > {thresholds.direct_wire_max_distance}mm)",
)
# Mid-range: check crossings
proposed = _pin_pair_wire(pins[0], pins[1])
crossings = estimate_crossings(proposed, existing_wires)
if crossings > thresholds.crossing_threshold:
return NetPlan(
net_name=net_name,
method=WiringMethod.LOCAL_LABEL,
pins=pins,
reason=f"crossing avoidance ({crossings} crossings > {thresholds.crossing_threshold})",
)
return NetPlan(
net_name=net_name,
method=WiringMethod.DIRECT_WIRE,
pins=pins,
reason=f"mid-range ({dist:.1f}mm), {crossings} crossing(s) within threshold",
)
# 6. 3-4 pin net -> local label (star topology)
return NetPlan(
net_name=net_name,
method=WiringMethod.LOCAL_LABEL,
pins=pins,
reason=f"multi-pin net ({len(pins)} pins, star topology cleaner with labels)",
)
def classify_all_nets(
nets: dict[str, list[list[str]]],
pin_positions: dict[str, dict[str, tuple[float, float]]],
thresholds: WiringThresholds,
pin_metadata: dict[str, dict[str, dict[str, str]]] | None = None,
existing_wires: list[WireSegment] | None = None,
cross_sheet_nets: set[str] | None = None,
exclude_nets: set[str] | None = None,
only_nets: set[str] | None = None,
) -> list[NetPlan]:
"""Classify all nets in a netlist using the wiring decision tree.
Args:
nets: Net name -> list of [ref, pin] pairs (from netlist parser).
pin_positions: ref -> pin_number -> (x, y) coordinate map.
thresholds: Tunable decision parameters.
pin_metadata: ref -> pin_number -> {pintype, pinfunction} from netlist.
existing_wires: Wire segments for crossing estimation.
cross_sheet_nets: Net names known to span multiple sheets.
exclude_nets: Nets to skip entirely.
only_nets: If provided, only classify these nets.
Returns:
List of NetPlan for each classifiable net.
"""
if pin_metadata is None:
pin_metadata = {}
if existing_wires is None:
existing_wires = []
if cross_sheet_nets is None:
cross_sheet_nets = set()
if exclude_nets is None:
exclude_nets = set()
plans: list[NetPlan] = []
for net_name, pin_pairs in nets.items():
if net_name in exclude_nets:
continue
if only_nets is not None and net_name not in only_nets:
continue
# Build PinInfo list from positions + metadata
pins: list[PinInfo] = []
for ref, pin_num in pin_pairs:
ref_pins = pin_positions.get(ref, {})
pos = ref_pins.get(pin_num)
if pos is None:
continue
meta = pin_metadata.get(ref, {}).get(pin_num, {})
pins.append(PinInfo(
reference=ref,
pin_number=pin_num,
x=pos[0],
y=pos[1],
pin_type=meta.get("pintype", ""),
pin_name=meta.get("pinfunction", ""),
))
if not pins:
continue
plan = classify_net(
net_name=net_name,
pins=pins,
thresholds=thresholds,
existing_wires=existing_wires,
is_cross_sheet=net_name in cross_sheet_nets,
)
plans.append(plan)
return plans

View File

@ -134,6 +134,13 @@ BATCH_LIMITS = {
"max_total_operations": 2000,
}
AUTOWIRE_DEFAULTS = {
"direct_wire_max_distance": 50.0,
"crossing_threshold": 2,
"high_fanout_threshold": 5,
"label_min_distance": 10.0,
}
DEFAULT_FOOTPRINTS = {
"R": [
"Resistor_SMD:R_0805_2012Metric",

View File

@ -48,6 +48,7 @@ from mckicad.resources import files, projects # noqa: E402, F401
from mckicad.resources import schematic as schematic_resources # noqa: E402, F401
from mckicad.tools import ( # noqa: E402, F401
analysis,
autowire,
batch,
bom,
drc,

View File

@ -0,0 +1,336 @@
"""Autowire tool for the mckicad MCP server.
Provides a single-call ``autowire_schematic`` tool that analyzes a
schematic's netlist, classifies each unconnected net into the best wiring
strategy, and optionally applies the result via the batch pipeline.
Strategy concepts informed by KICAD-autowire (MIT, arashmparsa).
Implementation is original, using mckicad's existing batch infrastructure.
"""
import json
import logging
import os
import subprocess
from typing import Any
from mckicad.autowire.planner import generate_batch_plan
from mckicad.autowire.strategy import (
WireSegment,
WiringThresholds,
classify_all_nets,
)
from mckicad.config import TIMEOUT_CONSTANTS
from mckicad.server import mcp
from mckicad.utils.kicad_cli import find_kicad_cli
logger = logging.getLogger(__name__)
_HAS_SCH_API = False
try:
from kicad_sch_api import load_schematic as _ksa_load
_HAS_SCH_API = True
except ImportError:
pass
def _require_sch_api() -> dict[str, Any] | None:
if not _HAS_SCH_API:
return {
"success": False,
"error": "kicad-sch-api is not installed. Install it with: uv add kicad-sch-api",
}
return None
def _expand(path: str) -> str:
return os.path.abspath(os.path.expanduser(path))
def _validate_schematic_path(path: str) -> dict[str, Any] | None:
if not path:
return {"success": False, "error": "Schematic path must be a non-empty string"}
expanded = os.path.expanduser(path)
if not expanded.endswith(".kicad_sch"):
return {"success": False, "error": f"Path must end with .kicad_sch, got: {path}"}
if not os.path.isfile(expanded):
return {"success": False, "error": f"Schematic file not found: {expanded}"}
return None
def _auto_export_netlist(schematic_path: str) -> tuple[str | None, str | None]:
"""Export a netlist via kicad-cli. Returns (path, error)."""
cli_path = find_kicad_cli()
if cli_path is None:
return None, "kicad-cli not found — provide netlist_path explicitly"
sidecar = os.path.join(os.path.dirname(schematic_path), ".mckicad")
os.makedirs(sidecar, exist_ok=True)
netlist_path = os.path.join(sidecar, "autowire_netlist.net")
cmd = [
cli_path, "sch", "export", "netlist",
"--format", "kicadsexpr",
"-o", netlist_path,
schematic_path,
]
try:
result = subprocess.run( # nosec B603
cmd,
capture_output=True,
text=True,
timeout=TIMEOUT_CONSTANTS["kicad_cli_export"],
check=False,
)
if os.path.isfile(netlist_path) and os.path.getsize(netlist_path) > 0:
return netlist_path, None
stderr = result.stderr.strip() if result.stderr else "netlist not created"
return None, f"kicad-cli netlist export failed: {stderr}"
except Exception as e:
return None, f"netlist export error: {e}"
def _build_pin_position_map(
all_pins: list[tuple[str, str, tuple[float, float]]],
) -> dict[str, dict[str, tuple[float, float]]]:
"""Convert the flat all_pins list to ref -> pin -> (x, y) map."""
result: dict[str, dict[str, tuple[float, float]]] = {}
for ref, pin_num, coord in all_pins:
result.setdefault(ref, {})[pin_num] = coord
return result
def _build_wire_segments(
wire_segment_dicts: list[dict[str, Any]],
) -> list[WireSegment]:
"""Convert connectivity state wire dicts to strategy WireSegments."""
segments: list[WireSegment] = []
for ws in wire_segment_dicts:
segments.append(WireSegment(
x1=ws["start"]["x"],
y1=ws["start"]["y"],
x2=ws["end"]["x"],
y2=ws["end"]["y"],
))
return segments
@mcp.tool()
def autowire_schematic(
schematic_path: str,
netlist_path: str | None = None,
dry_run: bool = True,
direct_wire_max_distance: float = 50.0,
crossing_threshold: int = 2,
high_fanout_threshold: int = 5,
exclude_nets: list[str] | None = None,
only_nets: list[str] | None = None,
exclude_refs: list[str] | None = None,
) -> dict[str, Any]:
"""Analyze unconnected nets and auto-wire them using optimal strategies.
Examines the schematic's netlist, classifies each unconnected net into
the best wiring method (direct wire, local label, global label, power
symbol, or no-connect), and optionally applies the wiring via the
batch pipeline.
The decision tree considers pin distance, wire crossing count, net
fanout, and power net naming patterns. Each strategy maps to existing
batch operations, inheriting collision detection and label placement.
By default runs in **dry_run** mode returns the plan without
modifying the schematic. Set ``dry_run=False`` to apply.
Args:
schematic_path: Path to a .kicad_sch file.
netlist_path: Path to a pre-exported netlist (.net file, kicad
s-expression format). Auto-exports via kicad-cli if omitted.
dry_run: Preview the plan without modifying the schematic (default
True). Set False to apply the wiring.
direct_wire_max_distance: Maximum pin distance (mm) for direct
wires. Beyond this, labels are used. Default 50.0.
crossing_threshold: Maximum acceptable wire crossings before
switching from direct wire to label. Default 2.
high_fanout_threshold: Pin count above which a net gets global
labels instead of local. Default 5.
exclude_nets: Net names to skip (e.g. already wired externally).
only_nets: If provided, only wire these specific nets.
exclude_refs: Component references to exclude from wiring.
Returns:
Dictionary with strategy summary (counts by method), the full
plan, batch JSON, and apply results (when dry_run=False).
"""
err = _require_sch_api()
if err:
return err
verr = _validate_schematic_path(schematic_path)
if verr:
return verr
schematic_path = _expand(schematic_path)
exclude_net_set = set(exclude_nets) if exclude_nets else set()
only_net_set = set(only_nets) if only_nets else None
exclude_ref_set = set(exclude_refs) if exclude_refs else set()
try:
# 1. Load schematic and build connectivity state
sch = _ksa_load(schematic_path)
from mckicad.tools.schematic_analysis import _build_connectivity_state
state = _build_connectivity_state(sch, schematic_path)
net_graph = state["net_graph"]
all_pins = state["all_pins"]
wire_segment_dicts = state["wire_segments"]
# 2. Get or auto-export netlist for pin metadata
if netlist_path:
netlist_path = _expand(netlist_path)
if not os.path.isfile(netlist_path):
return {"success": False, "error": f"Netlist file not found: {netlist_path}"}
else:
netlist_path, export_err = _auto_export_netlist(schematic_path)
if export_err or netlist_path is None:
return {"success": False, "error": export_err or "netlist export returned no path"}
# 3. Parse netlist
from mckicad.tools.netlist import _parse_kicad_sexp
assert netlist_path is not None # guaranteed by early returns above
with open(netlist_path) as f:
netlist_content = f.read()
parsed_netlist = _parse_kicad_sexp(netlist_content)
netlist_nets = parsed_netlist["nets"]
pin_metadata = parsed_netlist.get("pin_metadata", {})
# 4. Build pin position map from connectivity state
pin_positions = _build_pin_position_map(all_pins)
# 5. Filter out exclude_refs from pin positions
if exclude_ref_set:
pin_positions = {
ref: pins for ref, pins in pin_positions.items()
if ref not in exclude_ref_set
}
# 6. Identify already-connected nets (skip them)
connected_nets = set(net_graph.keys())
# Merge exclude sets
skip_nets = exclude_net_set | connected_nets
# 7. Build wire segments for crossing estimation
existing_wires = _build_wire_segments(wire_segment_dicts)
# 8. Classify all nets
thresholds = WiringThresholds(
direct_wire_max_distance=direct_wire_max_distance,
crossing_threshold=crossing_threshold,
high_fanout_threshold=high_fanout_threshold,
)
plans = classify_all_nets(
nets=netlist_nets,
pin_positions=pin_positions,
thresholds=thresholds,
pin_metadata=pin_metadata,
existing_wires=existing_wires,
exclude_nets=skip_nets,
only_nets=only_net_set,
)
# 9. Generate batch plan
existing_labels: set[str] = set()
for coord_text in state.get("label_at", {}).values():
existing_labels.add(coord_text)
batch_data = generate_batch_plan(plans, existing_labels)
# 10. Build strategy summary
method_counts: dict[str, int] = {}
for plan in plans:
method_counts[plan.method.value] = method_counts.get(plan.method.value, 0) + 1
plan_details = [
{
"net": p.net_name,
"method": p.method.value,
"pin_count": len(p.pins),
"reason": p.reason,
}
for p in plans
]
# 11. Write batch JSON to sidecar
sidecar = os.path.join(os.path.dirname(schematic_path), ".mckicad")
os.makedirs(sidecar, exist_ok=True)
batch_path = os.path.join(sidecar, "autowire_batch.json")
with open(batch_path, "w") as f:
json.dump(batch_data, f, indent=2)
result: dict[str, Any] = {
"success": True,
"dry_run": dry_run,
"strategy_summary": method_counts,
"total_nets_classified": len(plans),
"nets_skipped": len(skip_nets),
"plan": plan_details,
"batch_file": batch_path,
"schematic_path": schematic_path,
}
# 12. Apply if not dry_run
if not dry_run and batch_data:
from mckicad.tools.batch import (
_apply_batch_operations,
_register_project_libraries,
)
from mckicad.utils.sexp_parser import (
fix_property_private_keywords,
insert_sexp_before_close,
)
# Re-load schematic fresh for application
sch = _ksa_load(schematic_path)
_register_project_libraries(batch_data, schematic_path)
summary = _apply_batch_operations(sch, batch_data, schematic_path)
sch.save(schematic_path)
private_fixes = fix_property_private_keywords(schematic_path)
if private_fixes:
summary["property_private_fixes"] = private_fixes
pending_sexps = summary.pop("_pending_label_sexps", [])
if pending_sexps:
combined_sexp = "".join(pending_sexps)
insert_sexp_before_close(schematic_path, combined_sexp)
result["applied"] = {
"components_placed": summary.get("components_placed", 0),
"power_symbols_placed": summary.get("power_symbols_placed", 0),
"wires_placed": summary.get("wires_placed", 0),
"labels_placed": summary.get("labels_placed", 0),
"no_connects_placed": summary.get("no_connects_placed", 0),
"collisions_resolved": summary.get("collisions_resolved", 0),
"total_operations": summary.get("total_operations", 0),
}
logger.info(
"Autowire applied: %d operations to %s",
summary.get("total_operations", 0),
schematic_path,
)
elif not batch_data:
result["note"] = "No unconnected nets found — nothing to wire"
return result
except Exception as e:
logger.error("autowire_schematic failed: %s", e, exc_info=True)
return {"success": False, "error": str(e), "schematic_path": schematic_path}

391
tests/test_autowire.py Normal file
View File

@ -0,0 +1,391 @@
"""Tests for the autowire strategy, planner, and MCP tool."""
import json
import os
import pytest
from mckicad.autowire.planner import generate_batch_plan
from mckicad.autowire.strategy import (
NetPlan,
PinInfo,
WireSegment,
WiringMethod,
WiringThresholds,
classify_all_nets,
classify_net,
estimate_crossings,
pin_distance,
)
from tests.conftest import requires_sch_api
# ---------------------------------------------------------------------------
# Strategy unit tests
# ---------------------------------------------------------------------------
class TestPinDistance:
def test_same_point(self):
a = PinInfo(reference="R1", pin_number="1", x=100, y=100)
assert pin_distance(a, a) == 0.0
def test_horizontal(self):
a = PinInfo(reference="R1", pin_number="1", x=0, y=0)
b = PinInfo(reference="R2", pin_number="1", x=30, y=0)
assert pin_distance(a, b) == pytest.approx(30.0)
def test_diagonal(self):
a = PinInfo(reference="R1", pin_number="1", x=0, y=0)
b = PinInfo(reference="R2", pin_number="1", x=3, y=4)
assert pin_distance(a, b) == pytest.approx(5.0)
class TestCrossingEstimation:
def test_perpendicular_cross(self):
"""A horizontal wire crossing a vertical wire should count as 1."""
proposed = WireSegment(x1=0, y1=50, x2=100, y2=50)
existing = [WireSegment(x1=50, y1=0, x2=50, y2=100)]
assert estimate_crossings(proposed, existing) == 1
def test_parallel_no_cross(self):
"""Two parallel horizontal wires don't cross."""
proposed = WireSegment(x1=0, y1=50, x2=100, y2=50)
existing = [WireSegment(x1=0, y1=60, x2=100, y2=60)]
assert estimate_crossings(proposed, existing) == 0
def test_empty_existing(self):
proposed = WireSegment(x1=0, y1=0, x2=100, y2=0)
assert estimate_crossings(proposed, []) == 0
def test_multiple_crossings(self):
"""A horizontal wire crossing three vertical wires."""
proposed = WireSegment(x1=0, y1=50, x2=100, y2=50)
existing = [
WireSegment(x1=20, y1=0, x2=20, y2=100),
WireSegment(x1=50, y1=0, x2=50, y2=100),
WireSegment(x1=80, y1=0, x2=80, y2=100),
]
assert estimate_crossings(proposed, existing) == 3
def test_touching_endpoint_no_cross(self):
"""Wires that share an endpoint don't cross (strict inequality)."""
proposed = WireSegment(x1=0, y1=50, x2=100, y2=50)
existing = [WireSegment(x1=100, y1=0, x2=100, y2=100)]
assert estimate_crossings(proposed, existing) == 0
def test_collinear_no_cross(self):
"""Two collinear segments (same axis) don't cross."""
proposed = WireSegment(x1=0, y1=0, x2=100, y2=0)
existing = [WireSegment(x1=50, y1=0, x2=150, y2=0)]
assert estimate_crossings(proposed, existing) == 0
class TestWiringStrategy:
"""Test the classify_net decision tree."""
def _make_pins(self, count, spacing=5.0, pin_type="passive"):
return [
PinInfo(
reference=f"R{i+1}",
pin_number="1",
x=i * spacing,
y=0,
pin_type=pin_type,
)
for i in range(count)
]
def test_power_net_by_name(self):
pins = self._make_pins(2)
plan = classify_net("GND", pins, WiringThresholds())
assert plan.method == WiringMethod.POWER_SYMBOL
assert plan.power_lib_id is not None
def test_power_net_vcc(self):
pins = self._make_pins(2)
plan = classify_net("VCC", pins, WiringThresholds())
assert plan.method == WiringMethod.POWER_SYMBOL
def test_power_net_plus_3v3(self):
pins = self._make_pins(2)
plan = classify_net("+3V3", pins, WiringThresholds())
assert plan.method == WiringMethod.POWER_SYMBOL
def test_power_net_by_pin_type(self):
pins = [
PinInfo(reference="U1", pin_number="1", x=0, y=0, pin_type="power_in"),
PinInfo(reference="C1", pin_number="2", x=10, y=0, pin_type="passive"),
]
plan = classify_net("CUSTOM_RAIL", pins, WiringThresholds())
assert plan.method == WiringMethod.POWER_SYMBOL
def test_single_pin_no_connect(self):
pins = self._make_pins(1)
plan = classify_net("NC_PIN", pins, WiringThresholds())
assert plan.method == WiringMethod.NO_CONNECT
def test_cross_sheet_global_label(self):
pins = self._make_pins(2)
plan = classify_net("SPI_CLK", pins, WiringThresholds(), is_cross_sheet=True)
assert plan.method == WiringMethod.GLOBAL_LABEL
def test_high_fanout_global_label(self):
pins = self._make_pins(6, spacing=10)
plan = classify_net("DATA_BUS", pins, WiringThresholds(high_fanout_threshold=5))
assert plan.method == WiringMethod.GLOBAL_LABEL
assert "fanout" in plan.reason
def test_short_distance_direct_wire(self):
pins = [
PinInfo(reference="R1", pin_number="1", x=0, y=0),
PinInfo(reference="R2", pin_number="1", x=5, y=0),
]
plan = classify_net("NET1", pins, WiringThresholds())
assert plan.method == WiringMethod.DIRECT_WIRE
assert "short" in plan.reason
def test_long_distance_local_label(self):
pins = [
PinInfo(reference="R1", pin_number="1", x=0, y=0),
PinInfo(reference="R2", pin_number="1", x=100, y=0),
]
plan = classify_net("NET2", pins, WiringThresholds(direct_wire_max_distance=50))
assert plan.method == WiringMethod.LOCAL_LABEL
assert "long" in plan.reason
def test_mid_range_no_crossings_direct_wire(self):
pins = [
PinInfo(reference="R1", pin_number="1", x=0, y=0),
PinInfo(reference="R2", pin_number="1", x=30, y=0),
]
plan = classify_net("NET3", pins, WiringThresholds(), existing_wires=[])
assert plan.method == WiringMethod.DIRECT_WIRE
def test_mid_range_many_crossings_label(self):
pins = [
PinInfo(reference="R1", pin_number="1", x=0, y=0),
PinInfo(reference="R2", pin_number="1", x=0, y=30),
]
# Vertical wire from (0,0) to (0,30), crossing 3 horizontal wires
existing = [
WireSegment(x1=-10, y1=10, x2=10, y2=10),
WireSegment(x1=-10, y1=15, x2=10, y2=15),
WireSegment(x1=-10, y1=20, x2=10, y2=20),
]
plan = classify_net(
"NET4", pins,
WiringThresholds(crossing_threshold=2),
existing_wires=existing,
)
assert plan.method == WiringMethod.LOCAL_LABEL
assert "crossing" in plan.reason
def test_three_pin_net_local_label(self):
pins = self._make_pins(3, spacing=10)
plan = classify_net("NET5", pins, WiringThresholds())
assert plan.method == WiringMethod.LOCAL_LABEL
assert "multi-pin" in plan.reason
def test_custom_thresholds(self):
"""Custom thresholds override defaults."""
pins = [
PinInfo(reference="R1", pin_number="1", x=0, y=0),
PinInfo(reference="R2", pin_number="1", x=20, y=0),
]
# With default threshold (50mm), 20mm is mid-range -> direct wire
plan1 = classify_net("NET", pins, WiringThresholds())
assert plan1.method == WiringMethod.DIRECT_WIRE
# With tight threshold (15mm), 20mm exceeds max -> label
plan2 = classify_net("NET", pins, WiringThresholds(direct_wire_max_distance=15))
assert plan2.method == WiringMethod.LOCAL_LABEL
class TestClassifyAllNets:
def test_basic_classification(self):
nets = {
"GND": [["U1", "1"], ["C1", "2"]],
"SIG1": [["U1", "2"], ["R1", "1"]],
}
pin_positions = {
"U1": {"1": (100, 100), "2": (100, 110)},
"C1": {"2": (105, 100)},
"R1": {"1": (105, 110)},
}
plans = classify_all_nets(nets, pin_positions, WiringThresholds())
methods = {p.net_name: p.method for p in plans}
assert methods["GND"] == WiringMethod.POWER_SYMBOL
assert methods["SIG1"] == WiringMethod.DIRECT_WIRE
def test_exclude_nets(self):
nets = {"GND": [["U1", "1"]], "SIG": [["U1", "2"], ["R1", "1"]]}
pin_positions = {"U1": {"1": (0, 0), "2": (0, 10)}, "R1": {"1": (5, 10)}}
plans = classify_all_nets(
nets, pin_positions, WiringThresholds(),
exclude_nets={"GND"},
)
assert all(p.net_name != "GND" for p in plans)
def test_only_nets(self):
nets = {"GND": [["U1", "1"]], "SIG": [["U1", "2"], ["R1", "1"]]}
pin_positions = {"U1": {"1": (0, 0), "2": (0, 10)}, "R1": {"1": (5, 10)}}
plans = classify_all_nets(
nets, pin_positions, WiringThresholds(),
only_nets={"SIG"},
)
assert len(plans) == 1
assert plans[0].net_name == "SIG"
def test_missing_pin_positions_skipped(self):
"""Nets with pins not in pin_positions are skipped."""
nets = {"SIG": [["U99", "1"], ["U99", "2"]]}
pin_positions: dict = {}
plans = classify_all_nets(nets, pin_positions, WiringThresholds())
assert len(plans) == 0
def test_pin_metadata_power_detection(self):
"""Pin metadata with power_in type triggers POWER_SYMBOL."""
nets = {"CUSTOM_PWR": [["U1", "1"], ["C1", "1"]]}
pin_positions = {"U1": {"1": (0, 0)}, "C1": {"1": (5, 0)}}
pin_metadata = {"U1": {"1": {"pintype": "power_in", "pinfunction": "VCC"}}}
plans = classify_all_nets(
nets, pin_positions, WiringThresholds(),
pin_metadata=pin_metadata,
)
assert plans[0].method == WiringMethod.POWER_SYMBOL
# ---------------------------------------------------------------------------
# Planner unit tests
# ---------------------------------------------------------------------------
class TestBatchPlanGeneration:
def _make_plan(self, method, net="NET1", pin_count=2, **kwargs):
pins = [
PinInfo(reference=f"R{i+1}", pin_number=str(i+1), x=i*10, y=0)
for i in range(pin_count)
]
return NetPlan(net_name=net, method=method, pins=pins, **kwargs)
def test_direct_wire(self):
plan = self._make_plan(WiringMethod.DIRECT_WIRE)
batch = generate_batch_plan([plan])
assert "wires" in batch
assert len(batch["wires"]) == 1
wire = batch["wires"][0]
assert wire["from_ref"] == "R1"
assert wire["to_ref"] == "R2"
def test_local_label(self):
plan = self._make_plan(WiringMethod.LOCAL_LABEL, net="SIG1")
batch = generate_batch_plan([plan])
assert "label_connections" in batch
lc = batch["label_connections"][0]
assert lc["net"] == "SIG1"
assert lc["global"] is False
assert len(lc["connections"]) == 2
def test_global_label(self):
plan = self._make_plan(
WiringMethod.GLOBAL_LABEL, net="SPI_CLK",
label_shape="input",
)
batch = generate_batch_plan([plan])
lc = batch["label_connections"][0]
assert lc["global"] is True
assert lc["shape"] == "input"
def test_power_symbol(self):
plan = self._make_plan(
WiringMethod.POWER_SYMBOL, net="GND",
power_lib_id="power:GND",
)
batch = generate_batch_plan([plan])
assert "power_symbols" in batch
assert len(batch["power_symbols"]) == 2 # one per pin
assert batch["power_symbols"][0]["net"] == "GND"
assert batch["power_symbols"][0]["lib_id"] == "power:GND"
def test_no_connect(self):
plan = self._make_plan(WiringMethod.NO_CONNECT, pin_count=1)
batch = generate_batch_plan([plan])
assert "no_connects" in batch
assert len(batch["no_connects"]) == 1
def test_skip_produces_nothing(self):
plan = self._make_plan(WiringMethod.SKIP)
batch = generate_batch_plan([plan])
assert batch == {}
def test_mixed_plan(self):
plans = [
self._make_plan(WiringMethod.DIRECT_WIRE, net="NET1"),
self._make_plan(WiringMethod.LOCAL_LABEL, net="NET2"),
self._make_plan(WiringMethod.POWER_SYMBOL, net="GND", power_lib_id="power:GND"),
self._make_plan(WiringMethod.NO_CONNECT, net="NC1", pin_count=1),
]
batch = generate_batch_plan(plans)
assert "wires" in batch
assert "label_connections" in batch
assert "power_symbols" in batch
assert "no_connects" in batch
def test_existing_labels_skipped(self):
plan = self._make_plan(WiringMethod.LOCAL_LABEL, net="ALREADY_PLACED")
batch = generate_batch_plan([plan], existing_labels={"ALREADY_PLACED"})
assert "label_connections" not in batch
# ---------------------------------------------------------------------------
# Integration tests (require kicad-sch-api)
# ---------------------------------------------------------------------------
class TestAutowireTool:
@requires_sch_api
def test_dry_run_returns_plan(self, populated_schematic):
from mckicad.tools.autowire import autowire_schematic
result = autowire_schematic(populated_schematic, dry_run=True)
assert result["success"] is True
assert result["dry_run"] is True
assert "strategy_summary" in result
assert "plan" in result
assert "batch_file" in result
# Verify batch file was written
assert os.path.isfile(result["batch_file"])
with open(result["batch_file"]) as f:
batch_data = json.load(f)
assert isinstance(batch_data, dict)
@requires_sch_api
def test_dry_run_does_not_modify(self, populated_schematic):
"""dry_run=True should not change the schematic file."""
with open(populated_schematic) as f:
original_content = f.read()
from mckicad.tools.autowire import autowire_schematic
autowire_schematic(populated_schematic, dry_run=True)
with open(populated_schematic) as f:
after_content = f.read()
assert original_content == after_content
@requires_sch_api
def test_invalid_schematic_path(self):
from mckicad.tools.autowire import autowire_schematic
result = autowire_schematic("/nonexistent/path.kicad_sch")
assert result["success"] is False
@requires_sch_api
def test_invalid_extension(self):
from mckicad.tools.autowire import autowire_schematic
result = autowire_schematic("/some/file.txt")
assert result["success"] is False