WireViz/src/wireviz/wv_harness.py
Ryan Malloy 65af27e0da Add notebook-ready API: graph cache invalidation, structured BOM, fragment merging
Three additions to support interactive/notebook-style harness building:

- Graph cache invalidation: _invalidate_graph() called from all mutating
  methods so svg/png output reflects latest state after mutations
- bom_list_dicts(): JSON-serializable BOM export as list of dicts
- parse(harness=, populate_bom=): append YAML fragments to existing
  harness for cell-by-cell building with deferred BOM population

Templates persist on the Harness object across parse() calls so
component definitions in one fragment are available to connections
in later fragments.

Includes 24 new tests covering all three features plus full incremental
workflow simulation. All 122 tests pass.
2026-02-13 07:08:21 -07:00

492 lines
19 KiB
Python

# -*- coding: utf-8 -*-
import os
import shutil
from collections import defaultdict
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import List, Union
#from distutils.spawn import find_executable
from graphviz import Graph
import wireviz.wv_colors
from wireviz.wv_bom import BomCategory, BomEntry, bom_list, print_bom_table
from wireviz.wv_dataclasses import (
AUTOGENERATED_PREFIX,
AdditionalBomItem,
Arrow,
ArrowWeight,
Cable,
Component,
Connector,
MateComponent,
MatePin,
Metadata,
Options,
Side,
TopLevelGraphicalComponent,
Tweak,
Image,
)
from wireviz.wv_graphviz import (
apply_dot_tweaks,
calculate_node_bgcolor,
gv_connector_loops,
gv_connector_shorts,
gv_edge_mate,
gv_edge_wire,
gv_edge_wire_inside,
gv_node_component,
parse_arrow_str,
set_dot_basics,
)
from wireviz.wv_output import (
embed_svg_images,
embed_svg_images_file,
generate_html_output,
)
from wireviz.wv_utils import OLD_CONNECTOR_ATTR, bom2tsv, check_old, file_write_text, getAddCompFromRef
@dataclass
class Harness:
    """A wiring harness: connectors, cables, mates and additional BOM items.

    Components are added with ``add_connector`` / ``add_cable`` /
    ``add_additional_bom_item`` and linked with ``connect`` /
    ``add_mate_pin`` / ``add_mate_component``. Every mutating method clears the
    cached GraphViz graph, so ``graph``, ``svg`` and ``png`` reflect the latest
    state. The BOM is built on demand by ``populate_bom()``.
    """

    metadata: Metadata
    options: Options
    tweak: Tweak
    additional_bom_items: List[AdditionalBomItem] = field(default_factory=list)

    # cache for the GraphViz Graph object
    # do not access directly, use self.graph instead
    _graph = None

    def __post_init__(self):
        """Initialize the containers that are not dataclass fields."""
        self.connectors = {}  # designator -> Connector
        self.cables = {}  # designator -> Cable
        self.mates = []  # MatePin / MateComponent objects, in insertion order
        self.bom = defaultdict(dict)  # bom_hash -> BOM entry dict
        # Keep items passed to the constructor instead of discarding them;
        # copy so later additions never mutate the caller's list.
        self.additional_bom_items = list(self.additional_bom_items)
        # persistent template storage for fragment merging
        self._template_connectors = {}
        self._template_cables = {}
        self._graph = None

    def _invalidate_graph(self):
        """Clear the cached graph so the next access regenerates it."""
        self._graph = None

    def add_connector(self, designator: str, *args, **kwargs) -> None:
        """Create a ``Connector`` named *designator*, replacing any existing one.

        Extra arguments are forwarded to ``Connector``; *kwargs* are first
        checked against ``OLD_CONNECTOR_ATTR`` via ``check_old``.
        """
        check_old(f"Connector '{designator}'", OLD_CONNECTOR_ATTR, kwargs)
        self.connectors[designator] = Connector(designator=designator, *args, **kwargs)
        self._invalidate_graph()

    def add_cable(self, designator: str, *args, **kwargs) -> None:
        """Create a ``Cable`` named *designator*, replacing any existing one."""
        self.cables[designator] = Cable(designator=designator, *args, **kwargs)
        self._invalidate_graph()

    def add_additional_bom_item(self, item: dict) -> None:
        """Append an ``AdditionalBomItem`` built from the keyword dict *item*."""
        self.additional_bom_items.append(AdditionalBomItem(**item))
        self._invalidate_graph()

    def add_mate_pin(self, from_name, from_pin, to_name, to_pin, arrow_str) -> None:
        """Mate pin *from_pin* of *from_name* with pin *to_pin* of *to_name*.

        Pins may be given by number or by label; labels are resolved with
        ``Connector.resolve_pin``, as in ``connect()``. The mated pins are
        activated on the right/left side without counting as a connection.

        Raises:
            KeyError: if a connector or resolved pin does not exist.
        """
        from_con = self.connectors[from_name]
        to_con = self.connectors[to_name]
        # accept pin labels as well as pin numbers, consistent with connect()
        from_pin = from_con.resolve_pin(from_pin)
        to_pin = to_con.resolve_pin(to_pin)
        from_pin_obj = from_con.pin_objects[from_pin]
        to_pin_obj = to_con.pin_objects[to_pin]
        arrow = Arrow(direction=parse_arrow_str(arrow_str), weight=ArrowWeight.SINGLE)
        self.mates.append(MatePin(from_pin_obj, to_pin_obj, arrow))
        from_con.activate_pin(from_pin, Side.RIGHT, is_connection=False)
        to_con.activate_pin(to_pin, Side.LEFT, is_connection=False)
        self._invalidate_graph()

    def add_mate_component(self, from_name, to_name, arrow_str) -> None:
        """Record a component-level mate between *from_name* and *to_name*."""
        arrow = Arrow(direction=parse_arrow_str(arrow_str), weight=ArrowWeight.SINGLE)
        self.mates.append(MateComponent(from_name, to_name, arrow))
        self._invalidate_graph()

    def populate_bom(self):
        """Build ``self.bom`` from all components and assign BOM IDs.

        Intended to be called once harness creation is complete. It is also safe
        to call again after further mutations, because the BOM is rebuilt from
        scratch each time. Entries are sorted by category, then by description,
        and numbered from 1. Each BOM-relevant component receives the matching
        ``bom_id`` (used for BOM bubbles).
        """
        # Rebuild from scratch: re-running on the previous result would double
        # every quantity and fail on new items (self.bom is a plain dict after
        # the sort below, so missing keys would raise KeyError).
        self.bom = defaultdict(dict)

        # helper lists
        all_toplevel_items = (
            list(self.connectors.values())
            + list(self.cables.values())
            + self.additional_bom_items
        )
        all_subitems = [
            subitem
            for item in all_toplevel_items
            for subitem in item.additional_components
        ]
        all_bom_relevant_items = (
            list(self.connectors.values())
            + [cable for cable in self.cables.values() if cable.category != "bundle"]
            + [
                wire
                for cable in self.cables.values()
                if cable.category == "bundle"
                for wire in cable.wire_objects.values()
            ]
            + all_subitems
        )

        # add items to BOM (nested subitems are handled by _add_to_internal_bom)
        for item in all_toplevel_items:
            self._add_to_internal_bom(item)

        # sort BOM by category first, then alphabetically by description within category
        self.bom = dict(
            sorted(
                self.bom.items(),
                key=lambda kv: (kv[1]["category"], kv[0].description),  # kv = (key, value)
            )
        )

        # assign BOM IDs, starting at 1
        for bom_id, key in enumerate(self.bom, 1):
            self.bom[key]["id"] = bom_id

        # set BOM IDs within components (for BOM bubbles)
        for item in all_bom_relevant_items:
            if item.ignore_in_bom:
                continue
            if item.bom_hash not in self.bom:
                print(f"{item}'s hash not found in BOM dict.")  # Should not happen
                continue
            item.bom_id = self.bom[item.bom_hash]["id"]

        # component bom_id values changed; regenerate the graph on next access
        self._invalidate_graph()

    def _add_to_internal_bom(self, item: Component):
        """Add *item* and its additional components to ``self.bom``.

        Entries are keyed by ``bom_hash``. Quantities of identical parts are
        summed, and non-autogenerated designators are collected in a set.

        Raises:
            Exception: if *item* is neither a ``TopLevelGraphicalComponent``
                nor an ``AdditionalBomItem``.
        """
        if item.ignore_in_bom:
            return

        def _add(bom_hash, qty, designator=None, category=None):
            entry = self.bom[bom_hash]
            # initialize missing fields
            entry.setdefault("qty", 0)
            entry.setdefault("designators", set())
            # update fields
            entry["qty"] += qty
            if designator is None:
                designators = []
            elif isinstance(designator, list):
                designators = designator
            else:
                designators = [designator]
            for des in designators:
                if des and not des.startswith(AUTOGENERATED_PREFIX):
                    entry["designators"].add(des)
            entry["category"] = category

        if isinstance(item, TopLevelGraphicalComponent):
            if isinstance(item, Connector):
                cat = BomCategory.CONNECTOR
            elif isinstance(item, Cable):
                cat = BomCategory.WIRE if item.category == "bundle" else BomCategory.CABLE
            else:
                cat = ""

            if item.category == "bundle":
                # wires of a bundle are added as individual BOM entries
                for wire in item.wire_objects.values():
                    _add(
                        wire.bom_hash,
                        qty=item.qty,  # should be 1
                        designator=item.designator,  # inherit from parent item
                        category=cat,
                    )
            else:
                _add(
                    item.bom_hash,
                    qty=item.qty,  # should be 1
                    designator=item.designator,
                    category=cat,
                )

            if item.additional_components:
                item.compute_qty_multipliers()
                for comp in item.additional_components:
                    if comp.ignore_in_bom:
                        continue
                    total_qty = comp.qty_computed
                    if comp.sum_amounts_in_bom and comp.amount_computed:
                        # e.g. several lengths of the same material are summed up
                        total_qty = comp.qty_computed * comp.amount_computed.number
                    _add(
                        comp.bom_hash,
                        qty=total_qty,
                        designator=item.designator,
                        category=BomCategory.ADDITIONAL_INSIDE,
                    )
        elif isinstance(item, AdditionalBomItem):
            _add(
                item.bom_hash,
                qty=item.qty,
                designator=None,
                category=BomCategory.ADDITIONAL_OUTSIDE,
            )
        else:
            raise Exception(f"Unknown type of item:\n{item}")

    def connect(
        self,
        from_name: str,
        from_pin: Union[int, str],
        via_name: str,
        via_wire: Union[int, str],
        to_name: str,
        to_pin: Union[int, str],
    ) -> None:
        """Connect ``from_name:from_pin`` to ``to_name:to_pin`` via ``via_name:via_wire``.

        Either end may be ``None``, in which case no pin object is passed for
        that side. Pins of known connectors may be given by number or label
        (resolved with ``Connector.resolve_pin``). The wire may be given by
        number, by color or by wire label.

        Raises:
            Exception: if *via_wire* names different wires by color and by
                label, or matches more than one wire.
        """
        # resolve pin labels to pin numbers via Connector.resolve_pin()
        if from_name is not None and from_name in self.connectors:
            from_pin = self.connectors[from_name].resolve_pin(from_pin)
        if to_name is not None and to_name in self.connectors:
            to_pin = self.connectors[to_name].resolve_pin(to_pin)

        # resolve wire color / wire label to a wire number
        if via_name in self.cables:
            cable = self.cables[via_name]
            # check if provided name is ambiguous
            if via_wire in cable.colors and via_wire in cable.wirelabels:
                if cable.colors.index(via_wire) != cable.wirelabels.index(via_wire):
                    raise Exception(
                        f"{via_name}:{via_wire} is defined both in colors and wirelabels, "
                        "for different wires."
                    )
                # TODO: Maybe issue a warning if present in both lists
                # but referencing the same wire?
            if via_wire in cable.colors:
                wire_names = cable.colors
            elif via_wire in cable.wirelabels:
                wire_names = cable.wirelabels
            else:
                wire_names = None  # already a wire number
            if wire_names is not None:
                if wire_names.count(via_wire) > 1:
                    raise Exception(
                        f"{via_name}:{via_wire} is used for more than one wire."
                    )
                # list index starts at 0, wire IDs start at 1
                via_wire = wire_names.index(via_wire) + 1

        # perform the actual connection
        from_pin_obj = (
            self.connectors[from_name].pin_objects[from_pin]
            if from_name is not None
            else None
        )
        to_pin_obj = (
            self.connectors[to_name].pin_objects[to_pin]
            if to_name is not None
            else None
        )
        self.cables[via_name]._connect(from_pin_obj, via_wire, to_pin_obj)
        if from_name in self.connectors:
            self.connectors[from_name].activate_pin(from_pin, Side.RIGHT)
        if to_name in self.connectors:
            self.connectors[to_name].activate_pin(to_pin, Side.LEFT)
        self._invalidate_graph()

    def create_graph(self) -> Graph:
        """Build a new GraphViz graph for the current harness state.

        This always rebuilds; use the cached ``graph`` property instead. As a
        side effect it sets the module-level ``wireviz.wv_colors.padding_amount``.
        """
        dot = Graph()
        set_dot_basics(dot, self.options)

        for connector in self.connectors.values():
            # generate connector node
            gv_html = gv_node_component(connector)
            gv_html.update_attribs(
                bgcolor=calculate_node_bgcolor(connector, self.options)
            )
            dot.node(
                connector.designator,
                label=f"<\n{gv_html}\n>",
                shape="box",
                style="filled",
                href=connector.url if isinstance(connector.url, str) else "",
            )

            # generate edges for connector loops
            if len(connector.loops) > 0:
                dot.attr("edge", color="#000000", href="")
                for head, tail, color in gv_connector_loops(connector):
                    dot.edge(head, tail, color=color, label=" ", noLabel="noLabel")

            # generate edges for connector shorts
            if len(connector.shorts) > 0:
                dot.attr("edge", color="#000000", href="")
                for head, tail, color in gv_connector_shorts(connector):
                    dot.edge(
                        head,
                        tail,
                        color=color,
                        straight="straight",
                        # size of the point at the end of the straight line/edge;
                        # setting it also enables drawing that point
                        addPTS=".18",
                        colorPTS=color.replace("#FFFFFF:", ""),
                        headclip="false",
                        tailclip="false",
                    )

        # determine if there are double- or triple-colored wires in the harness;
        # if so, pad single-color wires to make all wires of equal thickness
        has_multicolor_wire = any(
            len(wire.color) > 1
            for cable in self.cables.values()
            for wire in cable.wire_objects.values()
        )
        wireviz.wv_colors.padding_amount = 3 if has_multicolor_wire else 1

        for cable in self.cables.values():
            # generate cable node
            # TODO: PN info for bundles (per wire)
            gv_html = gv_node_component(cable)
            gv_html.update_attribs(bgcolor=calculate_node_bgcolor(cable, self.options))
            dot.node(
                cable.designator,
                label=f"<\n{gv_html}\n>",
                shape="box",
                style="filled,dashed" if cable.category == "bundle" else "filled",
                href=cable.url if isinstance(cable.url, str) else "",
            )

            # generate wire edges between component nodes and cable nodes
            for connection in cable._connections:
                color, l1, l2, r1, r2 = gv_edge_wire(self, cable, connection)
                # determine per-wire URL for clickable edges
                wire_url = ""
                if connection.via is not None:
                    wire_idx = connection.via.index
                    if isinstance(cable.url, list):
                        if wire_idx < len(cable.url):
                            wire_url = cable.url[wire_idx] or ""
                    elif isinstance(cable.url, str):
                        wire_url = cable.url
                dot.attr("edge", color=color, href=wire_url)
                if (l1, l2) != (None, None):
                    dot.edge(l1, l2)
                if (r1, r2) != (None, None):
                    dot.edge(r1, r2)

            for color, we, ww in gv_edge_wire_inside(cable):
                if (we, ww) != (None, None):
                    dot.edge(we, ww, color=color, straight="straight", href="")

        for mate in self.mates:
            color, arrow_dir, code_from, code_to = gv_edge_mate(mate)
            dot.attr("edge", color=color, style="dashed", dir=arrow_dir, href="")
            dot.edge(code_from, code_to)

        apply_dot_tweaks(dot, self.tweak)
        return dot

    @property
    def graph(self):
        """The GraphViz graph for the current state, generated lazily and cached."""
        if self._graph is None:  # no cached graph exists, generate one
            self._graph = self.create_graph()
        return self._graph

    @property
    def png(self):
        """PNG rendering of ``graph`` as bytes (via ``graph.pipe``)."""
        return self.graph.pipe(format="png")

    @property
    def svg(self):  # TODO?: Verify xml encoding="utf-8" in SVG?
        """SVG rendering of ``graph`` as ``str``, passed through ``embed_svg_images``.

        The current working directory is passed as the base path.
        """
        svg_text = self.graph.pipe(format="svg").decode("utf-8")
        return embed_svg_images(svg_text, Path.cwd())

    def graphRender(self, type, filename, graph, view=False, cleanup=False):
        """Render *graph* to ``{filename}.{type}``.

        If ``dot``, ``gvpr`` and ``neato`` are all on PATH, the graph is laid
        out with ``dot``, post-processed by the ``wv_gvpr.gvpr`` script from
        this module's directory, and drawn with ``neato -n2``. Otherwise the
        method falls back to ``graph.render()``, the pre-jumper method, which
        may produce unwanted output.

        Args:
            type: GraphViz output format, e.g. ``"png"`` or ``"svg"``.
            filename: output path without the format extension.
            graph: the ``graphviz.Graph`` to render.
            view: open the rendered file (fallback renderer only).
            cleanup: delete the intermediate source file (fallback renderer
                only; the gvpr pipeline always removes its temporary file).

        Raises:
            subprocess.CalledProcessError: if a step of the gvpr pipeline fails.
        """
        import subprocess

        # Check if the needed commands exist
        if shutil.which("dot") and shutil.which("gvpr") and shutil.which("neato"):
            # gvpr finds wv_gvpr.gvpr through GVPRPATH; set it only for the
            # child process instead of mutating this process's environment
            gvpr_env = dict(os.environ, GVPRPATH=str(Path(__file__).parent))
            # export the gv source to a temporary file
            tmp_gv = f"{filename}_tmp.gv"
            graph.save(filename=tmp_gv)
            try:
                # argument lists instead of a shell string, so paths with
                # spaces or shell metacharacters are passed through safely
                laid_out = subprocess.run(
                    ["dot", tmp_gv], stdout=subprocess.PIPE, check=True
                ).stdout
                processed = subprocess.run(
                    ["gvpr", "-q", "-cf", "wv_gvpr.gvpr"],
                    input=laid_out,
                    stdout=subprocess.PIPE,
                    env=gvpr_env,
                    check=True,
                ).stdout
                subprocess.run(
                    ["neato", "-n2", f"-T{type}", "-o", f"{filename}.{type}"],
                    input=processed,
                    check=True,
                )
            finally:
                # remove the temporary file even if a step failed
                os.remove(tmp_gv)
        else:
            print('The "dot", "gvpr" and "neato" comand where not found on the system, use old methode of generaiton, this may lead to not wanted output.')
            # old rendering method, from before the jumper implementation
            graph.render(filename=filename, view=view, cleanup=cleanup)

    def output(
        self,
        filename: Union[str, Path],
        view: bool = False,
        cleanup: bool = True,
        fmt: tuple = ("html", "png", "svg", "tsv"),
    ) -> None:
        """Write the harness in the requested formats.

        Args:
            filename: output path without extension; each format appends its
                own suffix (``.png``, ``.svg``, ``.pdf``, ``.gv``, ``.bom.tsv``;
                HTML naming is up to ``generate_html_output``).
            view: forwarded to ``graphRender`` (honored by the fallback
                renderer only).
            cleanup: forwarded to ``graphRender`` (honored by the fallback
                renderer only).
            fmt: any of ``"html"``, ``"png"``, ``"svg"``, ``"pdf"``, ``"gv"``,
                ``"tsv"`` and ``"csv"`` (CSV is not supported yet).
        """
        # graphical output
        graph = self.graph
        # HTML embeds an SVG, so "html" is rendered as SVG; each GraphViz
        # format is rendered once even if both "svg" and "html" are requested
        render_formats = []
        for f in fmt:
            if f == "html":
                f = "svg"
            if f in ("png", "svg", "pdf") and f not in render_formats:
                render_formats.append(f)
        for f in render_formats:
            # SVG file will be renamed/deleted later
            _filename = f"{filename}.tmp" if f == "svg" else filename
            graph.format = f
            self.graphRender(f, _filename, graph, view=view, cleanup=cleanup)

        # embed images into SVG output
        if "svg" in fmt or "html" in fmt:
            embed_svg_images_file(f"{filename}.tmp.svg")

        # GraphViz source output
        if "gv" in fmt:
            graph.save(filename=f"{filename}.gv")
            # copy the gvpr script next to the output and print the command
            # needed to render the .gv file manually
            filename_str = str(filename)
            shutil.copyfile(
                Path(__file__).parent / "wv_gvpr.gvpr",
                filename_str + "_wv_gvpr.gvpr",
            )
            print(
                f"Use: dot {filename_str}.gv | gvpr -q -cf {filename_str}_wv_gvpr.gvpr "
                f"| neato -n2 -T<type> -o {filename_str}.<type>"
            )

        # BOM output
        bomlist = bom_list(self.bom)
        if "tsv" in fmt:
            file_write_text(f"{filename}.bom.tsv", bom2tsv(bomlist))
        if "csv" in fmt:
            # TODO: implement CSV output (preferably using the csv library)
            print("CSV output is not yet supported")

        # HTML output
        if "html" in fmt:
            generate_html_output(filename, bomlist, self.metadata, self.options)

        # PDF output is handled by GraphViz in the format loop above

        # delete SVG if not needed
        if "html" in fmt and "svg" not in fmt:
            # SVG file was just needed to generate HTML
            Path(f"{filename}.tmp.svg").unlink()
        elif "svg" in fmt:
            Path(f"{filename}.tmp.svg").replace(f"{filename}.svg")