diff --git a/src/mckicad/autowire/__init__.py b/src/mckicad/autowire/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/mckicad/autowire/planner.py b/src/mckicad/autowire/planner.py new file mode 100644 index 0000000..633330e --- /dev/null +++ b/src/mckicad/autowire/planner.py @@ -0,0 +1,105 @@ +"""Convert NetPlan decisions into batch JSON for apply_batch. + +Maps each wiring strategy to the batch schema that +``_apply_batch_operations()`` already consumes, so all file manipulation, +collision detection, and label serialization workarounds are inherited. +""" + +from __future__ import annotations + +from mckicad.autowire.strategy import NetPlan, WiringMethod + + +def generate_batch_plan( + plans: list[NetPlan], + existing_labels: set[str] | None = None, +) -> dict: + """Convert a list of NetPlan decisions into a batch JSON dict. + + Args: + plans: Classified net plans from ``classify_all_nets()``. + existing_labels: Net names that already have labels placed + (to avoid duplicates). + + Returns: + Dict matching the ``apply_batch`` JSON schema with keys: + ``wires``, ``label_connections``, ``power_symbols``, ``no_connects``. + """ + if existing_labels is None: + existing_labels = set() + + wires: list[dict] = [] + label_connections: list[dict] = [] + power_symbols: list[dict] = [] + no_connects: list[dict] = [] + + for plan in plans: + if plan.method == WiringMethod.SKIP: + continue + + if plan.method == WiringMethod.DIRECT_WIRE: + # Two-pin net: wire between the pins + if len(plan.pins) == 2: + wires.append({ + "from_ref": plan.pins[0].reference, + "from_pin": plan.pins[0].pin_number, + "to_ref": plan.pins[1].reference, + "to_pin": plan.pins[1].pin_number, + }) + + elif plan.method == WiringMethod.LOCAL_LABEL: + if plan.net_name in existing_labels: + continue + connections = [ + {"ref": pin.reference, "pin": pin.pin_number} + for pin in plan.pins + ] + label_connections.append({ + "net": plan.net_name, + "global": False, + "connections": connections, + }) + + elif plan.method == WiringMethod.GLOBAL_LABEL: + if plan.net_name in existing_labels: + continue + connections = [ + {"ref": pin.reference, "pin": pin.pin_number} + for pin in plan.pins + ] + label_connections.append({ + "net": plan.net_name, + "global": True, + "shape": plan.label_shape, + "connections": connections, + }) + + elif plan.method == WiringMethod.POWER_SYMBOL: + for pin in plan.pins: + entry: dict = { + "net": plan.net_name, + "pin_ref": pin.reference, + "pin_number": pin.pin_number, + } + if plan.power_lib_id: + entry["lib_id"] = plan.power_lib_id + power_symbols.append(entry) + + elif plan.method == WiringMethod.NO_CONNECT: + for pin in plan.pins: + no_connects.append({ + "pin_ref": pin.reference, + "pin_number": pin.pin_number, + }) + + result: dict = {} + if wires: + result["wires"] = wires + if label_connections: + result["label_connections"] = label_connections + if power_symbols: + result["power_symbols"] = power_symbols + if no_connects: + result["no_connects"] = no_connects + + return result diff --git a/src/mckicad/autowire/strategy.py b/src/mckicad/autowire/strategy.py new file mode 100644 index 0000000..9676550 --- /dev/null +++ b/src/mckicad/autowire/strategy.py @@ -0,0 +1,364 @@ +"""Pure wiring strategy logic for automated schematic connectivity. + +Decides how to connect each net (direct wire, local label, global label, +power symbol, no-connect, or skip) based on distance, fanout, crossing +count, and net name patterns. + +Strategy concepts informed by KICAD-autowire (MIT, arashmparsa). +Implementation is original, built on mckicad's batch infrastructure. + +No I/O or kicad-sch-api dependency — all functions operate on plain +dataclasses and return deterministic results for a given input. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import StrEnum +import math +import re + + +class WiringMethod(StrEnum): + """How a net should be physically connected in the schematic.""" + + DIRECT_WIRE = "direct_wire" + LOCAL_LABEL = "local_label" + GLOBAL_LABEL = "global_label" + POWER_SYMBOL = "power_symbol" + NO_CONNECT = "no_connect" + SKIP = "skip" + + +@dataclass +class PinInfo: + """A component pin with its schematic-space coordinates and metadata.""" + + reference: str + pin_number: str + x: float + y: float + pin_type: str = "" + pin_name: str = "" + + +@dataclass +class NetPlan: + """The chosen wiring strategy for a single net.""" + + net_name: str + method: WiringMethod + pins: list[PinInfo] + reason: str = "" + label_shape: str = "bidirectional" + power_lib_id: str | None = None + + +@dataclass +class WiringThresholds: + """Tunable parameters for the wiring decision tree.""" + + direct_wire_max_distance: float = 50.0 + crossing_threshold: int = 2 + high_fanout_threshold: int = 5 + label_min_distance: float = 10.0 + power_net_patterns: list[str] = field(default_factory=lambda: [ + "GND", "VCC", "VDD", "VSS", + "+3V3", "+5V", "+3.3V", "+12V", "+1.8V", + "VBUS", "VBAT", + ]) + + +# Pre-compiled pattern built from default list; rebuilt per-call when +# custom patterns are provided. +_DEFAULT_POWER_RE = re.compile( + r"^(GND[A-Z0-9_]*|VSS[A-Z0-9_]*|VCC[A-Z0-9_]*|VDD[A-Z0-9_]*" + r"|AGND|DGND|PGND|SGND" + r"|\+\d+V\d*|\+\d+\.\d+V" + r"|VBUS|VBAT)$", + re.IGNORECASE, +) + + +def _build_power_re(patterns: list[str]) -> re.Pattern[str]: + """Build a regex from a list of power net name patterns.""" + escaped = [re.escape(p) for p in patterns] + return re.compile(r"^(" + "|".join(escaped) + r")$", re.IGNORECASE) + + +def _is_power_net( + net_name: str, + pins: list[PinInfo], + thresholds: WiringThresholds, +) -> bool: + """Check if a net is a power rail by name or pin type.""" + name = net_name.strip() + + # Check against default broad pattern first + if _DEFAULT_POWER_RE.match(name): + return True + + # Check against user-provided patterns + if thresholds.power_net_patterns: + custom_re = _build_power_re(thresholds.power_net_patterns) + if custom_re.match(name): + return True + + # Check pin types — if any pin is power_in or power_out, treat as power + for pin in pins: + pt = pin.pin_type.lower() + if pt in ("power_in", "power_out"): + return True + + return False + + +def pin_distance(a: PinInfo, b: PinInfo) -> float: + """Euclidean distance between two pins in mm.""" + return math.sqrt((a.x - b.x) ** 2 + (a.y - b.y) ** 2) + + +# -- Crossing estimation --------------------------------------------------- + + +@dataclass +class WireSegment: + """An axis-aligned wire segment for crossing estimation.""" + + x1: float + y1: float + x2: float + y2: float + + +def _segments_cross(a: WireSegment, b: WireSegment) -> bool: + """Check if two axis-aligned segments cross (perpendicular intersection). + + A horizontal segment crosses a vertical segment when the vertical's + X falls within the horizontal's X range, and the horizontal's Y falls + within the vertical's Y range. + """ + a_horiz = abs(a.y1 - a.y2) < 0.01 + a_vert = abs(a.x1 - a.x2) < 0.01 + b_horiz = abs(b.y1 - b.y2) < 0.01 + b_vert = abs(b.x1 - b.x2) < 0.01 + + if a_horiz and b_vert: + h, v = a, b + elif a_vert and b_horiz: + h, v = b, a + else: + return False + + h_x_min = min(h.x1, h.x2) + h_x_max = max(h.x1, h.x2) + h_y = h.y1 + + v_y_min = min(v.y1, v.y2) + v_y_max = max(v.y1, v.y2) + v_x = v.x1 + + return (h_x_min < v_x < h_x_max) and (v_y_min < h_y < v_y_max) + + +def estimate_crossings( + proposed: WireSegment, + existing: list[WireSegment], +) -> int: + """Count how many existing wire segments the proposed wire would cross.""" + return sum(1 for seg in existing if _segments_cross(proposed, seg)) + + +def _pin_pair_wire(a: PinInfo, b: PinInfo) -> WireSegment: + """Create a hypothetical wire segment between two pins.""" + return WireSegment(x1=a.x, y1=a.y, x2=b.x, y2=b.y) + + +# -- Classification -------------------------------------------------------- + + +def classify_net( + net_name: str, + pins: list[PinInfo], + thresholds: WiringThresholds, + existing_wires: list[WireSegment] | None = None, + is_cross_sheet: bool = False, +) -> NetPlan: + """Choose a wiring method for a single net. + + Decision tree: + 1. Power net? -> POWER_SYMBOL + 2. Single-pin net? -> NO_CONNECT + 3. Cross-sheet net? -> GLOBAL_LABEL + 4. High fanout (>threshold pins)? -> GLOBAL_LABEL + 5. Two-pin net: + - distance <= label_min_distance -> DIRECT_WIRE + - distance > direct_wire_max_distance -> LOCAL_LABEL + - else: estimate crossings -> >threshold? LOCAL_LABEL : DIRECT_WIRE + 6. 3-4 pin net (below fanout threshold) -> LOCAL_LABEL + """ + if existing_wires is None: + existing_wires = [] + + # 1. Power net detection + if _is_power_net(net_name, pins, thresholds): + from mckicad.patterns._geometry import resolve_power_lib_id + + return NetPlan( + net_name=net_name, + method=WiringMethod.POWER_SYMBOL, + pins=pins, + power_lib_id=resolve_power_lib_id(net_name), + reason="power net (name or pin type match)", + ) + + # 2. Single-pin net -> no-connect + if len(pins) == 1: + return NetPlan( + net_name=net_name, + method=WiringMethod.NO_CONNECT, + pins=pins, + reason="single-pin net", + ) + + # 3. Cross-sheet net -> global label + if is_cross_sheet: + return NetPlan( + net_name=net_name, + method=WiringMethod.GLOBAL_LABEL, + pins=pins, + label_shape="bidirectional", + reason="cross-sheet net", + ) + + # 4. High fanout -> global label + if len(pins) >= thresholds.high_fanout_threshold: + return NetPlan( + net_name=net_name, + method=WiringMethod.GLOBAL_LABEL, + pins=pins, + label_shape="bidirectional", + reason=f"high fanout ({len(pins)} pins >= {thresholds.high_fanout_threshold})", + ) + + # 5. Two-pin net: distance + crossing heuristic + if len(pins) == 2: + dist = pin_distance(pins[0], pins[1]) + + if dist <= thresholds.label_min_distance: + return NetPlan( + net_name=net_name, + method=WiringMethod.DIRECT_WIRE, + pins=pins, + reason=f"short distance ({dist:.1f}mm <= {thresholds.label_min_distance}mm)", + ) + + if dist > thresholds.direct_wire_max_distance: + return NetPlan( + net_name=net_name, + method=WiringMethod.LOCAL_LABEL, + pins=pins, + reason=f"long distance ({dist:.1f}mm > {thresholds.direct_wire_max_distance}mm)", + ) + + # Mid-range: check crossings + proposed = _pin_pair_wire(pins[0], pins[1]) + crossings = estimate_crossings(proposed, existing_wires) + + if crossings > thresholds.crossing_threshold: + return NetPlan( + net_name=net_name, + method=WiringMethod.LOCAL_LABEL, + pins=pins, + reason=f"crossing avoidance ({crossings} crossings > {thresholds.crossing_threshold})", + ) + + return NetPlan( + net_name=net_name, + method=WiringMethod.DIRECT_WIRE, + pins=pins, + reason=f"mid-range ({dist:.1f}mm), {crossings} crossing(s) within threshold", + ) + + # 6. 3-4 pin net -> local label (star topology) + return NetPlan( + net_name=net_name, + method=WiringMethod.LOCAL_LABEL, + pins=pins, + reason=f"multi-pin net ({len(pins)} pins, star topology cleaner with labels)", + ) + + +def classify_all_nets( + nets: dict[str, list[list[str]]], + pin_positions: dict[str, dict[str, tuple[float, float]]], + thresholds: WiringThresholds, + pin_metadata: dict[str, dict[str, dict[str, str]]] | None = None, + existing_wires: list[WireSegment] | None = None, + cross_sheet_nets: set[str] | None = None, + exclude_nets: set[str] | None = None, + only_nets: set[str] | None = None, +) -> list[NetPlan]: + """Classify all nets in a netlist using the wiring decision tree. + + Args: + nets: Net name -> list of [ref, pin] pairs (from netlist parser). + pin_positions: ref -> pin_number -> (x, y) coordinate map. + thresholds: Tunable decision parameters. + pin_metadata: ref -> pin_number -> {pintype, pinfunction} from netlist. + existing_wires: Wire segments for crossing estimation. + cross_sheet_nets: Net names known to span multiple sheets. + exclude_nets: Nets to skip entirely. + only_nets: If provided, only classify these nets. + + Returns: + List of NetPlan for each classifiable net. + """ + if pin_metadata is None: + pin_metadata = {} + if existing_wires is None: + existing_wires = [] + if cross_sheet_nets is None: + cross_sheet_nets = set() + if exclude_nets is None: + exclude_nets = set() + + plans: list[NetPlan] = [] + + for net_name, pin_pairs in nets.items(): + if net_name in exclude_nets: + continue + if only_nets is not None and net_name not in only_nets: + continue + + # Build PinInfo list from positions + metadata + pins: list[PinInfo] = [] + for ref, pin_num in pin_pairs: + ref_pins = pin_positions.get(ref, {}) + pos = ref_pins.get(pin_num) + if pos is None: + continue + + meta = pin_metadata.get(ref, {}).get(pin_num, {}) + pins.append(PinInfo( + reference=ref, + pin_number=pin_num, + x=pos[0], + y=pos[1], + pin_type=meta.get("pintype", ""), + pin_name=meta.get("pinfunction", ""), + )) + + if not pins: + continue + + plan = classify_net( + net_name=net_name, + pins=pins, + thresholds=thresholds, + existing_wires=existing_wires, + is_cross_sheet=net_name in cross_sheet_nets, + ) + plans.append(plan) + + return plans diff --git a/src/mckicad/config.py b/src/mckicad/config.py index e4ba141..5e5d5e3 100644 --- a/src/mckicad/config.py +++ b/src/mckicad/config.py @@ -134,6 +134,13 @@ BATCH_LIMITS = { "max_total_operations": 2000, } +AUTOWIRE_DEFAULTS = { + "direct_wire_max_distance": 50.0, + "crossing_threshold": 2, + "high_fanout_threshold": 5, + "label_min_distance": 10.0, +} + DEFAULT_FOOTPRINTS = { "R": [ "Resistor_SMD:R_0805_2012Metric", diff --git a/src/mckicad/server.py b/src/mckicad/server.py index e725170..9523d02 100644 --- a/src/mckicad/server.py +++ b/src/mckicad/server.py @@ -48,6 +48,7 @@ from mckicad.resources import files, projects # noqa: E402, F401 from mckicad.resources import schematic as schematic_resources # noqa: E402, F401 from mckicad.tools import ( # noqa: E402, F401 analysis, + autowire, batch, bom, drc, diff --git a/src/mckicad/tools/autowire.py b/src/mckicad/tools/autowire.py new file mode 100644 index 0000000..2d12251 --- /dev/null +++ b/src/mckicad/tools/autowire.py @@ -0,0 +1,336 @@ +"""Autowire tool for the mckicad MCP server. + +Provides a single-call ``autowire_schematic`` tool that analyzes a +schematic's netlist, classifies each unconnected net into the best wiring +strategy, and optionally applies the result via the batch pipeline. + +Strategy concepts informed by KICAD-autowire (MIT, arashmparsa). +Implementation is original, using mckicad's existing batch infrastructure. +""" + +import json +import logging +import os +import subprocess +from typing import Any + +from mckicad.autowire.planner import generate_batch_plan +from mckicad.autowire.strategy import ( + WireSegment, + WiringThresholds, + classify_all_nets, +) +from mckicad.config import TIMEOUT_CONSTANTS +from mckicad.server import mcp +from mckicad.utils.kicad_cli import find_kicad_cli + +logger = logging.getLogger(__name__) + +_HAS_SCH_API = False +try: + from kicad_sch_api import load_schematic as _ksa_load + + _HAS_SCH_API = True +except ImportError: + pass + + +def _require_sch_api() -> dict[str, Any] | None: + if not _HAS_SCH_API: + return { + "success": False, + "error": "kicad-sch-api is not installed. Install it with: uv add kicad-sch-api", + } + return None + + +def _expand(path: str) -> str: + return os.path.abspath(os.path.expanduser(path)) + + +def _validate_schematic_path(path: str) -> dict[str, Any] | None: + if not path: + return {"success": False, "error": "Schematic path must be a non-empty string"} + expanded = os.path.expanduser(path) + if not expanded.endswith(".kicad_sch"): + return {"success": False, "error": f"Path must end with .kicad_sch, got: {path}"} + if not os.path.isfile(expanded): + return {"success": False, "error": f"Schematic file not found: {expanded}"} + return None + + +def _auto_export_netlist(schematic_path: str) -> tuple[str | None, str | None]: + """Export a netlist via kicad-cli. Returns (path, error).""" + cli_path = find_kicad_cli() + if cli_path is None: + return None, "kicad-cli not found — provide netlist_path explicitly" + + sidecar = os.path.join(os.path.dirname(schematic_path), ".mckicad") + os.makedirs(sidecar, exist_ok=True) + netlist_path = os.path.join(sidecar, "autowire_netlist.net") + + cmd = [ + cli_path, "sch", "export", "netlist", + "--format", "kicadsexpr", + "-o", netlist_path, + schematic_path, + ] + + try: + result = subprocess.run( # nosec B603 + cmd, + capture_output=True, + text=True, + timeout=TIMEOUT_CONSTANTS["kicad_cli_export"], + check=False, + ) + if os.path.isfile(netlist_path) and os.path.getsize(netlist_path) > 0: + return netlist_path, None + stderr = result.stderr.strip() if result.stderr else "netlist not created" + return None, f"kicad-cli netlist export failed: {stderr}" + except Exception as e: + return None, f"netlist export error: {e}" + + +def _build_pin_position_map( + all_pins: list[tuple[str, str, tuple[float, float]]], +) -> dict[str, dict[str, tuple[float, float]]]: + """Convert the flat all_pins list to ref -> pin -> (x, y) map.""" + result: dict[str, dict[str, tuple[float, float]]] = {} + for ref, pin_num, coord in all_pins: + result.setdefault(ref, {})[pin_num] = coord + return result + + +def _build_wire_segments( + wire_segment_dicts: list[dict[str, Any]], +) -> list[WireSegment]: + """Convert connectivity state wire dicts to strategy WireSegments.""" + segments: list[WireSegment] = [] + for ws in wire_segment_dicts: + segments.append(WireSegment( + x1=ws["start"]["x"], + y1=ws["start"]["y"], + x2=ws["end"]["x"], + y2=ws["end"]["y"], + )) + return segments + + +@mcp.tool() +def autowire_schematic( + schematic_path: str, + netlist_path: str | None = None, + dry_run: bool = True, + direct_wire_max_distance: float = 50.0, + crossing_threshold: int = 2, + high_fanout_threshold: int = 5, + exclude_nets: list[str] | None = None, + only_nets: list[str] | None = None, + exclude_refs: list[str] | None = None, +) -> dict[str, Any]: + """Analyze unconnected nets and auto-wire them using optimal strategies. + + Examines the schematic's netlist, classifies each unconnected net into + the best wiring method (direct wire, local label, global label, power + symbol, or no-connect), and optionally applies the wiring via the + batch pipeline. + + The decision tree considers pin distance, wire crossing count, net + fanout, and power net naming patterns. Each strategy maps to existing + batch operations, inheriting collision detection and label placement. + + By default runs in **dry_run** mode — returns the plan without + modifying the schematic. Set ``dry_run=False`` to apply. + + Args: + schematic_path: Path to a .kicad_sch file. + netlist_path: Path to a pre-exported netlist (.net file, kicad + s-expression format). Auto-exports via kicad-cli if omitted. + dry_run: Preview the plan without modifying the schematic (default + True). Set False to apply the wiring. + direct_wire_max_distance: Maximum pin distance (mm) for direct + wires. Beyond this, labels are used. Default 50.0. + crossing_threshold: Maximum acceptable wire crossings before + switching from direct wire to label. Default 2. + high_fanout_threshold: Pin count above which a net gets global + labels instead of local. Default 5. + exclude_nets: Net names to skip (e.g. already wired externally). + only_nets: If provided, only wire these specific nets. + exclude_refs: Component references to exclude from wiring. + + Returns: + Dictionary with strategy summary (counts by method), the full + plan, batch JSON, and apply results (when dry_run=False). + """ + err = _require_sch_api() + if err: + return err + + verr = _validate_schematic_path(schematic_path) + if verr: + return verr + + schematic_path = _expand(schematic_path) + exclude_net_set = set(exclude_nets) if exclude_nets else set() + only_net_set = set(only_nets) if only_nets else None + exclude_ref_set = set(exclude_refs) if exclude_refs else set() + + try: + # 1. Load schematic and build connectivity state + sch = _ksa_load(schematic_path) + + from mckicad.tools.schematic_analysis import _build_connectivity_state + + state = _build_connectivity_state(sch, schematic_path) + net_graph = state["net_graph"] + all_pins = state["all_pins"] + wire_segment_dicts = state["wire_segments"] + + # 2. Get or auto-export netlist for pin metadata + if netlist_path: + netlist_path = _expand(netlist_path) + if not os.path.isfile(netlist_path): + return {"success": False, "error": f"Netlist file not found: {netlist_path}"} + else: + netlist_path, export_err = _auto_export_netlist(schematic_path) + if export_err or netlist_path is None: + return {"success": False, "error": export_err or "netlist export returned no path"} + + # 3. Parse netlist + from mckicad.tools.netlist import _parse_kicad_sexp + + assert netlist_path is not None # guaranteed by early returns above + with open(netlist_path) as f: + netlist_content = f.read() + parsed_netlist = _parse_kicad_sexp(netlist_content) + netlist_nets = parsed_netlist["nets"] + pin_metadata = parsed_netlist.get("pin_metadata", {}) + + # 4. Build pin position map from connectivity state + pin_positions = _build_pin_position_map(all_pins) + + # 5. Filter out exclude_refs from pin positions + if exclude_ref_set: + pin_positions = { + ref: pins for ref, pins in pin_positions.items() + if ref not in exclude_ref_set + } + + # 6. Identify already-connected nets (skip them) + connected_nets = set(net_graph.keys()) + + # Merge exclude sets + skip_nets = exclude_net_set | connected_nets + + # 7. Build wire segments for crossing estimation + existing_wires = _build_wire_segments(wire_segment_dicts) + + # 8. Classify all nets + thresholds = WiringThresholds( + direct_wire_max_distance=direct_wire_max_distance, + crossing_threshold=crossing_threshold, + high_fanout_threshold=high_fanout_threshold, + ) + + plans = classify_all_nets( + nets=netlist_nets, + pin_positions=pin_positions, + thresholds=thresholds, + pin_metadata=pin_metadata, + existing_wires=existing_wires, + exclude_nets=skip_nets, + only_nets=only_net_set, + ) + + # 9. Generate batch plan + existing_labels: set[str] = set() + for coord_text in state.get("label_at", {}).values(): + existing_labels.add(coord_text) + + batch_data = generate_batch_plan(plans, existing_labels) + + # 10. Build strategy summary + method_counts: dict[str, int] = {} + for plan in plans: + method_counts[plan.method.value] = method_counts.get(plan.method.value, 0) + 1 + + plan_details = [ + { + "net": p.net_name, + "method": p.method.value, + "pin_count": len(p.pins), + "reason": p.reason, + } + for p in plans + ] + + # 11. Write batch JSON to sidecar + sidecar = os.path.join(os.path.dirname(schematic_path), ".mckicad") + os.makedirs(sidecar, exist_ok=True) + batch_path = os.path.join(sidecar, "autowire_batch.json") + with open(batch_path, "w") as f: + json.dump(batch_data, f, indent=2) + + result: dict[str, Any] = { + "success": True, + "dry_run": dry_run, + "strategy_summary": method_counts, + "total_nets_classified": len(plans), + "nets_skipped": len(skip_nets), + "plan": plan_details, + "batch_file": batch_path, + "schematic_path": schematic_path, + } + + # 12. Apply if not dry_run + if not dry_run and batch_data: + from mckicad.tools.batch import ( + _apply_batch_operations, + _register_project_libraries, + ) + from mckicad.utils.sexp_parser import ( + fix_property_private_keywords, + insert_sexp_before_close, + ) + + # Re-load schematic fresh for application + sch = _ksa_load(schematic_path) + _register_project_libraries(batch_data, schematic_path) + + summary = _apply_batch_operations(sch, batch_data, schematic_path) + + sch.save(schematic_path) + + private_fixes = fix_property_private_keywords(schematic_path) + if private_fixes: + summary["property_private_fixes"] = private_fixes + + pending_sexps = summary.pop("_pending_label_sexps", []) + if pending_sexps: + combined_sexp = "".join(pending_sexps) + insert_sexp_before_close(schematic_path, combined_sexp) + + result["applied"] = { + "components_placed": summary.get("components_placed", 0), + "power_symbols_placed": summary.get("power_symbols_placed", 0), + "wires_placed": summary.get("wires_placed", 0), + "labels_placed": summary.get("labels_placed", 0), + "no_connects_placed": summary.get("no_connects_placed", 0), + "collisions_resolved": summary.get("collisions_resolved", 0), + "total_operations": summary.get("total_operations", 0), + } + + logger.info( + "Autowire applied: %d operations to %s", + summary.get("total_operations", 0), + schematic_path, + ) + elif not batch_data: + result["note"] = "No unconnected nets found — nothing to wire" + + return result + + except Exception as e: + logger.error("autowire_schematic failed: %s", e, exc_info=True) + return {"success": False, "error": str(e), "schematic_path": schematic_path} diff --git a/tests/test_autowire.py b/tests/test_autowire.py new file mode 100644 index 0000000..1cc2530 --- /dev/null +++ b/tests/test_autowire.py @@ -0,0 +1,391 @@ +"""Tests for the autowire strategy, planner, and MCP tool.""" + +import json +import os + +import pytest + +from mckicad.autowire.planner import generate_batch_plan +from mckicad.autowire.strategy import ( + NetPlan, + PinInfo, + WireSegment, + WiringMethod, + WiringThresholds, + classify_all_nets, + classify_net, + estimate_crossings, + pin_distance, +) +from tests.conftest import requires_sch_api + +# --------------------------------------------------------------------------- +# Strategy unit tests +# --------------------------------------------------------------------------- + + +class TestPinDistance: + def test_same_point(self): + a = PinInfo(reference="R1", pin_number="1", x=100, y=100) + assert pin_distance(a, a) == 0.0 + + def test_horizontal(self): + a = PinInfo(reference="R1", pin_number="1", x=0, y=0) + b = PinInfo(reference="R2", pin_number="1", x=30, y=0) + assert pin_distance(a, b) == pytest.approx(30.0) + + def test_diagonal(self): + a = PinInfo(reference="R1", pin_number="1", x=0, y=0) + b = PinInfo(reference="R2", pin_number="1", x=3, y=4) + assert pin_distance(a, b) == pytest.approx(5.0) + + +class TestCrossingEstimation: + def test_perpendicular_cross(self): + """A horizontal wire crossing a vertical wire should count as 1.""" + proposed = WireSegment(x1=0, y1=50, x2=100, y2=50) + existing = [WireSegment(x1=50, y1=0, x2=50, y2=100)] + assert estimate_crossings(proposed, existing) == 1 + + def test_parallel_no_cross(self): + """Two parallel horizontal wires don't cross.""" + proposed = WireSegment(x1=0, y1=50, x2=100, y2=50) + existing = [WireSegment(x1=0, y1=60, x2=100, y2=60)] + assert estimate_crossings(proposed, existing) == 0 + + def test_empty_existing(self): + proposed = WireSegment(x1=0, y1=0, x2=100, y2=0) + assert estimate_crossings(proposed, []) == 0 + + def test_multiple_crossings(self): + """A horizontal wire crossing three vertical wires.""" + proposed = WireSegment(x1=0, y1=50, x2=100, y2=50) + existing = [ + WireSegment(x1=20, y1=0, x2=20, y2=100), + WireSegment(x1=50, y1=0, x2=50, y2=100), + WireSegment(x1=80, y1=0, x2=80, y2=100), + ] + assert estimate_crossings(proposed, existing) == 3 + + def test_touching_endpoint_no_cross(self): + """Wires that share an endpoint don't cross (strict inequality).""" + proposed = WireSegment(x1=0, y1=50, x2=100, y2=50) + existing = [WireSegment(x1=100, y1=0, x2=100, y2=100)] + assert estimate_crossings(proposed, existing) == 0 + + def test_collinear_no_cross(self): + """Two collinear segments (same axis) don't cross.""" + proposed = WireSegment(x1=0, y1=0, x2=100, y2=0) + existing = [WireSegment(x1=50, y1=0, x2=150, y2=0)] + assert estimate_crossings(proposed, existing) == 0 + + +class TestWiringStrategy: + """Test the classify_net decision tree.""" + + def _make_pins(self, count, spacing=5.0, pin_type="passive"): + return [ + PinInfo( + reference=f"R{i+1}", + pin_number="1", + x=i * spacing, + y=0, + pin_type=pin_type, + ) + for i in range(count) + ] + + def test_power_net_by_name(self): + pins = self._make_pins(2) + plan = classify_net("GND", pins, WiringThresholds()) + assert plan.method == WiringMethod.POWER_SYMBOL + assert plan.power_lib_id is not None + + def test_power_net_vcc(self): + pins = self._make_pins(2) + plan = classify_net("VCC", pins, WiringThresholds()) + assert plan.method == WiringMethod.POWER_SYMBOL + + def test_power_net_plus_3v3(self): + pins = self._make_pins(2) + plan = classify_net("+3V3", pins, WiringThresholds()) + assert plan.method == WiringMethod.POWER_SYMBOL + + def test_power_net_by_pin_type(self): + pins = [ + PinInfo(reference="U1", pin_number="1", x=0, y=0, pin_type="power_in"), + PinInfo(reference="C1", pin_number="2", x=10, y=0, pin_type="passive"), + ] + plan = classify_net("CUSTOM_RAIL", pins, WiringThresholds()) + assert plan.method == WiringMethod.POWER_SYMBOL + + def test_single_pin_no_connect(self): + pins = self._make_pins(1) + plan = classify_net("NC_PIN", pins, WiringThresholds()) + assert plan.method == WiringMethod.NO_CONNECT + + def test_cross_sheet_global_label(self): + pins = self._make_pins(2) + plan = classify_net("SPI_CLK", pins, WiringThresholds(), is_cross_sheet=True) + assert plan.method == WiringMethod.GLOBAL_LABEL + + def test_high_fanout_global_label(self): + pins = self._make_pins(6, spacing=10) + plan = classify_net("DATA_BUS", pins, WiringThresholds(high_fanout_threshold=5)) + assert plan.method == WiringMethod.GLOBAL_LABEL + assert "fanout" in plan.reason + + def test_short_distance_direct_wire(self): + pins = [ + PinInfo(reference="R1", pin_number="1", x=0, y=0), + PinInfo(reference="R2", pin_number="1", x=5, y=0), + ] + plan = classify_net("NET1", pins, WiringThresholds()) + assert plan.method == WiringMethod.DIRECT_WIRE + assert "short" in plan.reason + + def test_long_distance_local_label(self): + pins = [ + PinInfo(reference="R1", pin_number="1", x=0, y=0), + PinInfo(reference="R2", pin_number="1", x=100, y=0), + ] + plan = classify_net("NET2", pins, WiringThresholds(direct_wire_max_distance=50)) + assert plan.method == WiringMethod.LOCAL_LABEL + assert "long" in plan.reason + + def test_mid_range_no_crossings_direct_wire(self): + pins = [ + PinInfo(reference="R1", pin_number="1", x=0, y=0), + PinInfo(reference="R2", pin_number="1", x=30, y=0), + ] + plan = classify_net("NET3", pins, WiringThresholds(), existing_wires=[]) + assert plan.method == WiringMethod.DIRECT_WIRE + + def test_mid_range_many_crossings_label(self): + pins = [ + PinInfo(reference="R1", pin_number="1", x=0, y=0), + PinInfo(reference="R2", pin_number="1", x=0, y=30), + ] + # Vertical wire from (0,0) to (0,30), crossing 3 horizontal wires + existing = [ + WireSegment(x1=-10, y1=10, x2=10, y2=10), + WireSegment(x1=-10, y1=15, x2=10, y2=15), + WireSegment(x1=-10, y1=20, x2=10, y2=20), + ] + plan = classify_net( + "NET4", pins, + WiringThresholds(crossing_threshold=2), + existing_wires=existing, + ) + assert plan.method == WiringMethod.LOCAL_LABEL + assert "crossing" in plan.reason + + def test_three_pin_net_local_label(self): + pins = self._make_pins(3, spacing=10) + plan = classify_net("NET5", pins, WiringThresholds()) + assert plan.method == WiringMethod.LOCAL_LABEL + assert "multi-pin" in plan.reason + + def test_custom_thresholds(self): + """Custom thresholds override defaults.""" + pins = [ + PinInfo(reference="R1", pin_number="1", x=0, y=0), + PinInfo(reference="R2", pin_number="1", x=20, y=0), + ] + # With default threshold (50mm), 20mm is mid-range -> direct wire + plan1 = classify_net("NET", pins, WiringThresholds()) + assert plan1.method == WiringMethod.DIRECT_WIRE + + # With tight threshold (15mm), 20mm exceeds max -> label + plan2 = classify_net("NET", pins, WiringThresholds(direct_wire_max_distance=15)) + assert plan2.method == WiringMethod.LOCAL_LABEL + + +class TestClassifyAllNets: + def test_basic_classification(self): + nets = { + "GND": [["U1", "1"], ["C1", "2"]], + "SIG1": [["U1", "2"], ["R1", "1"]], + } + pin_positions = { + "U1": {"1": (100, 100), "2": (100, 110)}, + "C1": {"2": (105, 100)}, + "R1": {"1": (105, 110)}, + } + plans = classify_all_nets(nets, pin_positions, WiringThresholds()) + + methods = {p.net_name: p.method for p in plans} + assert methods["GND"] == WiringMethod.POWER_SYMBOL + assert methods["SIG1"] == WiringMethod.DIRECT_WIRE + + def test_exclude_nets(self): + nets = {"GND": [["U1", "1"]], "SIG": [["U1", "2"], ["R1", "1"]]} + pin_positions = {"U1": {"1": (0, 0), "2": (0, 10)}, "R1": {"1": (5, 10)}} + plans = classify_all_nets( + nets, pin_positions, WiringThresholds(), + exclude_nets={"GND"}, + ) + assert all(p.net_name != "GND" for p in plans) + + def test_only_nets(self): + nets = {"GND": [["U1", "1"]], "SIG": [["U1", "2"], ["R1", "1"]]} + pin_positions = {"U1": {"1": (0, 0), "2": (0, 10)}, "R1": {"1": (5, 10)}} + plans = classify_all_nets( + nets, pin_positions, WiringThresholds(), + only_nets={"SIG"}, + ) + assert len(plans) == 1 + assert plans[0].net_name == "SIG" + + def test_missing_pin_positions_skipped(self): + """Nets with pins not in pin_positions are skipped.""" + nets = {"SIG": [["U99", "1"], ["U99", "2"]]} + pin_positions: dict = {} + plans = classify_all_nets(nets, pin_positions, WiringThresholds()) + assert len(plans) == 0 + + def test_pin_metadata_power_detection(self): + """Pin metadata with power_in type triggers POWER_SYMBOL.""" + nets = {"CUSTOM_PWR": [["U1", "1"], ["C1", "1"]]} + pin_positions = {"U1": {"1": (0, 0)}, "C1": {"1": (5, 0)}} + pin_metadata = {"U1": {"1": {"pintype": "power_in", "pinfunction": "VCC"}}} + plans = classify_all_nets( + nets, pin_positions, WiringThresholds(), + pin_metadata=pin_metadata, + ) + assert plans[0].method == WiringMethod.POWER_SYMBOL + + +# --------------------------------------------------------------------------- +# Planner unit tests +# --------------------------------------------------------------------------- + + +class TestBatchPlanGeneration: + def _make_plan(self, method, net="NET1", pin_count=2, **kwargs): + pins = [ + PinInfo(reference=f"R{i+1}", pin_number=str(i+1), x=i*10, y=0) + for i in range(pin_count) + ] + return NetPlan(net_name=net, method=method, pins=pins, **kwargs) + + def test_direct_wire(self): + plan = self._make_plan(WiringMethod.DIRECT_WIRE) + batch = generate_batch_plan([plan]) + assert "wires" in batch + assert len(batch["wires"]) == 1 + wire = batch["wires"][0] + assert wire["from_ref"] == "R1" + assert wire["to_ref"] == "R2" + + def test_local_label(self): + plan = self._make_plan(WiringMethod.LOCAL_LABEL, net="SIG1") + batch = generate_batch_plan([plan]) + assert "label_connections" in batch + lc = batch["label_connections"][0] + assert lc["net"] == "SIG1" + assert lc["global"] is False + assert len(lc["connections"]) == 2 + + def test_global_label(self): + plan = self._make_plan( + WiringMethod.GLOBAL_LABEL, net="SPI_CLK", + label_shape="input", + ) + batch = generate_batch_plan([plan]) + lc = batch["label_connections"][0] + assert lc["global"] is True + assert lc["shape"] == "input" + + def test_power_symbol(self): + plan = self._make_plan( + WiringMethod.POWER_SYMBOL, net="GND", + power_lib_id="power:GND", + ) + batch = generate_batch_plan([plan]) + assert "power_symbols" in batch + assert len(batch["power_symbols"]) == 2 # one per pin + assert batch["power_symbols"][0]["net"] == "GND" + assert batch["power_symbols"][0]["lib_id"] == "power:GND" + + def test_no_connect(self): + plan = self._make_plan(WiringMethod.NO_CONNECT, pin_count=1) + batch = generate_batch_plan([plan]) + assert "no_connects" in batch + assert len(batch["no_connects"]) == 1 + + def test_skip_produces_nothing(self): + plan = self._make_plan(WiringMethod.SKIP) + batch = generate_batch_plan([plan]) + assert batch == {} + + def test_mixed_plan(self): + plans = [ + self._make_plan(WiringMethod.DIRECT_WIRE, net="NET1"), + self._make_plan(WiringMethod.LOCAL_LABEL, net="NET2"), + self._make_plan(WiringMethod.POWER_SYMBOL, net="GND", power_lib_id="power:GND"), + self._make_plan(WiringMethod.NO_CONNECT, net="NC1", pin_count=1), + ] + batch = generate_batch_plan(plans) + assert "wires" in batch + assert "label_connections" in batch + assert "power_symbols" in batch + assert "no_connects" in batch + + def test_existing_labels_skipped(self): + plan = self._make_plan(WiringMethod.LOCAL_LABEL, net="ALREADY_PLACED") + batch = generate_batch_plan([plan], existing_labels={"ALREADY_PLACED"}) + assert "label_connections" not in batch + + +# --------------------------------------------------------------------------- +# Integration tests (require kicad-sch-api) +# --------------------------------------------------------------------------- + + +class TestAutowireTool: + @requires_sch_api + def test_dry_run_returns_plan(self, populated_schematic): + from mckicad.tools.autowire import autowire_schematic + + result = autowire_schematic(populated_schematic, dry_run=True) + assert result["success"] is True + assert result["dry_run"] is True + assert "strategy_summary" in result + assert "plan" in result + assert "batch_file" in result + + # Verify batch file was written + assert os.path.isfile(result["batch_file"]) + with open(result["batch_file"]) as f: + batch_data = json.load(f) + assert isinstance(batch_data, dict) + + @requires_sch_api + def test_dry_run_does_not_modify(self, populated_schematic): + """dry_run=True should not change the schematic file.""" + with open(populated_schematic) as f: + original_content = f.read() + + from mckicad.tools.autowire import autowire_schematic + + autowire_schematic(populated_schematic, dry_run=True) + + with open(populated_schematic) as f: + after_content = f.read() + + assert original_content == after_content + + @requires_sch_api + def test_invalid_schematic_path(self): + from mckicad.tools.autowire import autowire_schematic + + result = autowire_schematic("/nonexistent/path.kicad_sch") + assert result["success"] is False + + @requires_sch_api + def test_invalid_extension(self): + from mckicad.tools.autowire import autowire_schematic + + result = autowire_schematic("/some/file.txt") + assert result["success"] is False