"""Graphviz-based system graph visualization utilities.
This module renders SysSimX systems as Graphviz directed graphs with:
- grouped component clusters,
- port-aware record labels,
- event connections styled distinctly,
- algebraic loop connections highlighted, and
- unit labels on edges, shown when the destination port declares a unit.
"""
from __future__ import annotations

import colorsys
import hashlib
import html
import os
from collections.abc import Iterable

from graphviz import Digraph
from IPython.display import display

from ..core.base import CoSimComponent
from ..core.port import PortSpec, PortType
from ..system.connection import Connection
from ..system.graph import collect_active_outputs
from ..system.system import System
# -----------------------------------------------------------------------------
# Styling constants
# -----------------------------------------------------------------------------
# Color of regular (non-loop) data edges; also the graph-wide edge default.
COLOR_EDGE = "#111827"
# Intended color for edge label text; passed to Graphviz as the graph-wide
# ``labelfontcolor`` edge attribute.
COLOR_EDGE_LABEL = "#374151"
# Highlight for event ports and event connection edges.
COLOR_EVENT = "#0b84ff"
# Highlight for direct-feedthrough output ports and algebraic-loop edges.
COLOR_FEEDTHROUGH = "#ef4444"
# Border of the legend cluster; also reused as the border of group clusters.
COLOR_LEGEND_BORDER = "#d1d5db"
# Background of the legend cluster.
COLOR_LEGEND_BG = "#f9fafb"
# Font for node and edge text.
FONT_DEFAULT = "Helvetica"
# Font for the graph title, the legend title and group cluster titles.
FONT_BOLD = "Helvetica-Bold"
# Fixed cluster background colors for well-known group names. Any other group
# name gets a generated pastel color (see _auto_color_for_group).
DEFAULT_GROUP_PALETTE = {
    "Reference": "#dbeafe",  # light blue
    "Sensors": "#dcfce7",  # light green
    "Control": "#fee2e2",  # light red
    "Actuator": "#fef3c7",  # light yellow
    "Plant": "#e5e7eb",  # light gray
}
# ----------------------------------------------------------------------------
# Private helper functions
# ----------------------------------------------------------------------------
[docs]
def _auto_color_for_group(
    group_name: str, base_saturation: float = 0.35, base_lightness: float = 0.85
) -> str:
    """Generate a deterministic pastel color for a group name.

    The MD5 digest of ``group_name`` seeds the hue and a small jitter around
    ``base_saturation`` / ``base_lightness``. The same name therefore always
    maps to the same color across runs.

    Args:
        group_name: Name of the group to color.
        base_saturation: Base HLS saturation (nominally 0..1). A jitter of
            -0.10..+0.09 is added, then the result is clamped to [0, 1].
        base_lightness: Base HLS lightness (nominally 0..1). A jitter of
            -0.050..+0.045 is added, then the result is clamped to [0, 1].

    Returns:
        Hex color string of the form ``"#rrggbb"`` (e.g., "#aabbcc").
    """
    # MD5 serves only as a stable, well-mixed hash here, not for security.
    hash_val = int(hashlib.md5(group_name.encode()).hexdigest(), 16)
    hue = (hash_val % 360) / 360.0
    # Note: 20 divides 360, so the saturation jitter is determined by the
    # hue bucket; kept as-is so existing group colors do not change.
    saturation = base_saturation + ((hash_val % 20) - 10) * 0.01
    lightness = base_lightness + (((hash_val >> 8) % 20) - 10) * 0.005
    # Clamp to the valid HLS range. Out-of-range values make hls_to_rgb return
    # channels above 1.0, which would format as more than two hex digits.
    saturation = min(1.0, max(0.0, saturation))
    lightness = min(1.0, max(0.0, lightness))
    r, g, b = colorsys.hls_to_rgb(hue, lightness, saturation)
    return f"#{int(r * 255):02x}{int(g * 255):02x}{int(b * 255):02x}"
[docs]
def _build_palette(system: System) -> dict[str, str]:
    """Return the group-to-color mapping used for cluster backgrounds.

    Starts from DEFAULT_GROUP_PALETTE and adds a generated pastel color for
    every group of ``system`` that has no predefined entry.

    Args:
        system: System whose ``groups`` are inspected.

    Returns:
        A new dict mapping group name to hex color. The module-level default
        palette itself is left untouched.
    """
    palette = {**DEFAULT_GROUP_PALETTE}
    palette.update(
        (name, _auto_color_for_group(name))
        for name in system.groups
        if name not in DEFAULT_GROUP_PALETTE
    )
    return palette
[docs]
def _is_event_port(port: PortSpec) -> bool:
    """Tell whether ``port`` carries events rather than real values.

    Args:
        port: Port specification to inspect.

    Returns:
        True for any port whose ``type`` is not ``PortType.REAL``.
    """
    port_kind = port.type
    return port_kind != PortType.REAL
[docs]
def _build_ports_column(
    port_names: Iterable[str],
    event_ports: set[str],
    feedthrough_ports: set[str] | None = None,
) -> str:
    """Render one port column (inputs or outputs) as an HTML table cell.

    Each port becomes a cell with a ``PORT`` anchor so edges can attach to it.
    Event ports are colored with COLOR_EVENT. Otherwise, ports listed in
    ``feedthrough_ports`` are colored with COLOR_FEEDTHROUGH.

    Args:
        port_names: Port names in display order (consumed once).
        event_ports: Names of event ports.
        feedthrough_ports: Optional names of direct-feedthrough output ports.

    Returns:
        ``"<TD></TD>"`` when there are no ports, a single anchored ``<TD>``
        for one port, or otherwise a ``<TD>`` wrapping a nested table with one
        row per port.
    """
    names = list(port_names)
    if not names:
        return "<TD></TD>"
    highlighted = feedthrough_ports if feedthrough_ports else set()

    def anchored_cell(port: str) -> str:
        # Event styling takes precedence over feedthrough styling.
        if port in event_ports:
            text = _format_port_label(port, COLOR_EVENT)
        elif port in highlighted:
            text = _format_port_label(port, COLOR_FEEDTHROUGH)
        else:
            text = _format_port_label(port)
        return f'<TD PORT="{port}">{text}</TD>'

    if len(names) == 1:
        return anchored_cell(names[0])
    rows = "".join(f"<TR>{anchored_cell(port)}</TR>" for port in names)
    return (
        '<TD><TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" CELLPADDING="2">'
        f"{rows}</TABLE></TD>"
    )
[docs]
def _record_label_for_component(comp: CoSimComponent, execution_idx: int = -1) -> str:
    """Build an HTML-like record label for a component.

    The label is a three-column table:
    - the input ports,
    - the component name, with an optional ``#<index>`` execution-order badge,
    - the output ports.
    Every port cell carries an explicit ``PORT`` anchor. Output ports with
    direct feedthrough are highlighted in red; event ports are highlighted in
    blue.

    Args:
        comp: Component to render.
        execution_idx: Execution order index to display; a negative value
            suppresses the badge.

    Returns:
        Graphviz HTML-like label string (wrapped in ``<...>``).
    """
    inputs = sorted(comp.input_specs.keys())
    outputs = sorted(comp.output_specs.keys())
    event_inputs = {p.name for p in comp.input_specs.values() if _is_event_port(p)}
    event_outputs = {p.name for p in comp.output_specs.values() if _is_event_port(p)}
    # An output counts as feedthrough when it lists at least one dependency.
    feedthrough_outputs = {name for name, deps in comp.direct_feedthrough.items() if deps}
    header = "<TR><TD>in</TD><TD>name</TD><TD>out</TD></TR>"
    badge = f"#{execution_idx}" if execution_idx >= 0 else ""
    # Escape the display text: a raw '&', '<' or '>' in a component name would
    # make the HTML-like label malformed and Graphviz would reject the graph.
    display_name = html.escape(str(comp.name), quote=False)
    name_cell = f'<TD><B>{display_name}</B><BR/><FONT POINT-SIZE="8">{badge}</FONT></TD>'
    inputs_cell = _build_ports_column(inputs, event_inputs)
    outputs_cell = _build_ports_column(outputs, event_outputs, feedthrough_outputs)
    label = (
        '<TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" CELLPADDING="4">'
        f"{header}<TR>{inputs_cell}{name_cell}{outputs_cell}</TR></TABLE>"
    )
    return f"<{label}>"
[docs]
def _edge_unit_label(system: System, conn: Connection) -> str | None:
    """Look up the unit of the input port that a connection feeds into.

    Args:
        system: System containing the destination component.
        conn: Connection whose destination port is inspected.

    Returns:
        The destination port's unit formatted as a string. Returns None when
        the component or port is unknown, or the port has no (truthy) unit.
    """
    target = system.components.get(conn.dst_comp)
    if target is None:
        return None
    spec = target.input_specs.get(conn.dst_port)
    if not spec or not spec.unit:
        return None
    return format(spec.unit)
[docs]
def _is_zero_delay_connection(
    system: System, conn: Connection, active_outputs: dict[str, set[str]]
) -> bool:
    """Decide whether a connection feeds an active direct-feedthrough path.

    The connection qualifies when its destination input is listed as a
    dependency of some direct-feedthrough output of the destination
    component, and that output appears in ``active_outputs``.

    Args:
        system: System containing the components.
        conn: Connection to evaluate.
        active_outputs: Mapping of component name to the output ports that
            are used by connections.

    Returns:
        True if this connection contributes to the zero-delay dependency graph.
    """
    target = system.components.get(conn.dst_comp)
    if target is None:
        return False
    used_outputs = active_outputs.get(conn.dst_comp, set())
    return any(
        bool(deps) and conn.dst_port in deps
        for out_port, deps in target.direct_feedthrough.items()
        if out_port in used_outputs
    )
[docs]
def _is_direct_feedthrough_output(comp: CoSimComponent, port_name: str) -> bool:
    """Tell whether an output port depends directly on some input.

    Args:
        comp: Component to inspect.
        port_name: Output port name on the component.

    Returns:
        True if ``comp.direct_feedthrough`` lists a non-empty dependency
        collection for ``port_name``; False if the entry is missing or empty.
    """
    return bool(comp.direct_feedthrough.get(port_name))
[docs]
def _build_loop_index(system: System) -> dict[str, int]:
    """Map each component name to the index of its algebraic loop.

    Loops are numbered in the order of ``system.algebraic_loops``. If a name
    appears in several loops, the last loop wins.

    Args:
        system: System containing detected algebraic loops.

    Returns:
        Mapping from component name to algebraic loop index.
    """
    return {
        member: loop_idx
        for loop_idx, members in enumerate(system.algebraic_loops)
        for member in members
    }
[docs]
def _legend_label() -> str:
    """Return the HTML-like label used for the legend node.

    The legend has two columns:
    - edge styles: data flow, algebraic loop, event connection;
    - port highlight colors: direct feedthrough, event port.
    The highlight colors come from COLOR_FEEDTHROUGH and COLOR_EVENT. The
    legend therefore always matches the colors used on the nodes and edges.

    Returns:
        Graphviz HTML-like label string.
    """
    feedthrough = COLOR_FEEDTHROUGH
    event = COLOR_EVENT
    return f"""<
<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="0">
<TR>
<TD ALIGN="LEFT" VALIGN="TOP">
<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="1">
<TR><TD COLSPAN="2" ALIGN="LEFT"><FONT POINT-SIZE="8"><B>Edges:</B></FONT></TD></TR>
<TR>
<TD ALIGN="LEFT" VALIGN="MIDDLE"><FONT POINT-SIZE="9">──▶</FONT></TD>
<TD ALIGN="LEFT" VALIGN="MIDDLE"><FONT POINT-SIZE="8">Data flow</FONT></TD>
</TR>
<TR>
<TD ALIGN="LEFT" VALIGN="MIDDLE"><FONT POINT-SIZE="9" COLOR="{feedthrough}"><B>━━▶</B></FONT></TD>
<TD ALIGN="LEFT" VALIGN="MIDDLE"><FONT POINT-SIZE="8">Algebraic loop</FONT></TD>
</TR>
<TR>
<TD ALIGN="LEFT" VALIGN="MIDDLE"><FONT POINT-SIZE="9" COLOR="{event}"><B>╍╍○</B></FONT></TD>
<TD ALIGN="LEFT" VALIGN="MIDDLE"><FONT POINT-SIZE="8">Event connection</FONT></TD>
</TR>
</TABLE>
</TD>
<TD WIDTH="8"></TD>
<TD ALIGN="LEFT" VALIGN="TOP">
<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0" CELLPADDING="1">
<TR><TD ALIGN="LEFT"><FONT POINT-SIZE="8"><B>Ports:</B></FONT></TD></TR>
<TR><TD ALIGN="LEFT"><FONT POINT-SIZE="8" COLOR="{feedthrough}"><B>Direct feedthrough</B></FONT></TD></TR>
<TR><TD ALIGN="LEFT"><FONT POINT-SIZE="8" COLOR="{event}"><B>Event port</B></FONT></TD></TR>
</TABLE>
</TD>
</TR>
</TABLE>
>"""
# ----------------------------------------------------------------------------
# SystemGraphVisualizer class
# ----------------------------------------------------------------------------
[docs]
class SystemGraphVisualizer:
    """Render a System as a Graphviz directed graph.

    Call ``visualize()`` to build, render and display the graph. Afterwards,
    ``save()`` can re-render the same graph to another path or format.
    """

    def __init__(self, system: System):
        """Initialize the visualizer.

        Args:
            system: System instance to visualize.
        """
        self.system = system
        # Digraph built by the most recent visualize() call, if any.
        self.dot: Digraph | None = None
        # Names of components placed inside group clusters; written by
        # _add_grouped_components and read by _add_ungrouped_components.
        self._grouped_names: set[str] = set()

    def visualize(self, filename: str = "system_graph", format: str = "svg") -> None:
        """Build, render and display a Graphviz system graph.

        A new Digraph replaces any previously built one. If the system has
        connections but its DAG has no nodes yet, ``system.build_graphs()``
        is called first. The graph then receives, in order:
        - a legend,
        - one cluster per component group,
        - the ungrouped components,
        - data edges (algebraic-loop edges in red),
        - event edges (dashed blue).
        It is rendered to disk without opening a viewer and then passed to
        IPython's ``display()``.

        Args:
            filename: Output path. A trailing file extension is optional and
                is stripped; Graphviz appends the ``format`` extension. Dots in
                directory names and a leading "./" are preserved.
            format: Graphviz output format (e.g., "svg", "png").
        """
        self.dot = Digraph(
            comment=f"System: {self.system.name}",
            format=format,
            engine="dot",
            graph_attr={
                "rankdir": "LR",
                "splines": "true",
                "pad": "0.5",
                "nodesep": "0.8",
                "ranksep": "0.8",
                "esep": "0",
                "sep": "0",
                "fontsize": "18",
                "fontname": FONT_BOLD,
                "bgcolor": "white",
                "margin": "0.2",
                "labelloc": "t",
                "label": self.system.name,
                "compound": "true",
            },
            node_attr={
                "shape": "plaintext",
                "style": "",
                "fillcolor": "white",
                "fontname": FONT_DEFAULT,
                "fontsize": "10",
                "margin": "0.05,0.05",
            },
            edge_attr={
                "arrowsize": "0.7",
                "penwidth": "1.2",
                "color": COLOR_EDGE,
                "fontname": FONT_DEFAULT,
                "fontsize": "10",
                # Data edges carry their unit text in the main ``label``,
                # which Graphviz colors with ``fontcolor``. ``labelfontcolor``
                # only affects head/tail labels.
                "fontcolor": COLOR_EDGE_LABEL,
                "labelfontcolor": COLOR_EDGE_LABEL,
                "labelangle": "45",
                "dir": "forward",
                "arrowhead": "vee",
                "arrowtail": "none",
                "minlen": "1.2",
            },
        )
        if self.system.connections and self.system._dag.number_of_nodes() == 0:
            self.system.build_graphs()
        active_outputs = collect_active_outputs(self.system)
        loop_index = _build_loop_index(self.system)
        palette = _build_palette(self.system)
        self._add_legend()
        self._add_grouped_components(palette)
        self._add_ungrouped_components()
        self._add_data_edges(active_outputs, loop_index)
        self._add_event_edges()
        # splitext only strips a real extension of the last path component;
        # rsplit(".") would turn "./graph" into "" and "a.d/graph" into "a".
        render_name = os.path.splitext(filename)[0]
        self.dot.render(filename=render_name, view=False)
        display(self.dot)

    def save(self, filepath: str) -> None:
        """Re-render the current graph to ``filepath``.

        The extension of ``filepath`` (case-insensitive) selects the Graphviz
        output format. When there is no extension, the graph's current format
        is kept. The intermediate DOT source is removed (``cleanup=True``).

        Args:
            filepath: Output path, normally with extension (e.g., "graph.svg").

        Raises:
            RuntimeError: If visualize() was not called before saving.
        """
        if self.dot is None:
            raise RuntimeError("Graph has not been visualized yet. Call visualize() before saving.")
        root, ext = os.path.splitext(filepath)
        fmt = ext[1:].lower()
        if fmt:
            # Note: this also changes the format used by later renders.
            self.dot.format = fmt
        self.dot.render(filename=root, view=False, cleanup=True)

    def _add_legend(self) -> None:
        """Add a legend cluster to the current Graphviz graph."""
        if self.dot is None:
            return
        with self.dot.subgraph(name="cluster_legend") as legend:
            legend.attr(
                label="Legend",
                style="rounded",
                color=COLOR_LEGEND_BORDER,
                bgcolor=COLOR_LEGEND_BG,
                fontname=FONT_BOLD,
                fontsize="9",
                penwidth="1.0",
                margin="3",
                rank="sink",
            )
            legend.node("legend_node", label=_legend_label(), shape="plaintext", fontsize="8")

    def _add_grouped_components(self, palette: dict[str, str]) -> None:
        """Add grouped components as clustered subgraphs.

        Group members may be component objects or component names; names are
        resolved through ``system.components``. The names of all placed
        components are recorded in ``self._grouped_names``.

        Args:
            palette: Group name to color mapping.
        """
        if self.dot is None:
            return
        grouped_names: set[str] = set()
        for group, members in self.system.groups.items():
            with self.dot.subgraph(name=f"cluster_{group}") as cluster:
                cluster.attr(
                    label=group,
                    style="rounded",
                    color=COLOR_LEGEND_BORDER,
                    bgcolor=palette.get(group, "white"),
                    fontname=FONT_BOLD,
                    fontsize="11",
                    penwidth="2",
                    margin="12",
                )
                for member in members:
                    comp = (
                        member
                        if isinstance(member, CoSimComponent)
                        else self.system.components[str(member)]
                    )
                    grouped_names.add(comp.name)
                    gen = self.system.execution_idx.get(comp.name, -1)
                    cluster.node(comp.name, label=_record_label_for_component(comp, gen))
        self._grouped_names = grouped_names

    def _add_ungrouped_components(self) -> None:
        """Add any components that were not placed in a group cluster."""
        if self.dot is None:
            return
        grouped_names = self._grouped_names
        for comp_name, comp in self.system.components.items():
            if comp_name in grouped_names:
                continue
            gen = self.system.execution_idx.get(comp.name, -1)
            self.dot.node(comp.name, label=_record_label_for_component(comp, gen))

    def _add_data_edges(
        self, active_outputs: dict[str, set[str]], loop_index: dict[str, int]
    ) -> None:
        """Add standard data connection edges.

        Each edge is labeled ``[unit]`` when the destination port has a unit.
        An edge is highlighted as an algebraic-loop edge (red, thicker) only
        when all of the following hold:
        - both endpoints are in the same detected loop;
        - the source output is direct-feedthrough;
        - the destination input passes ``_is_direct_feedthrough_input``;
        - the connection is part of the zero-delay dependency graph.

        Args:
            active_outputs: Mapping of component names to active output ports.
            loop_index: Mapping of component names to algebraic loop index.
        """
        if self.dot is None:
            return
        for conn in self.system.connections:
            src_anchor = f"{conn.src_port}:e"
            dst_anchor = f"{conn.dst_port}:w"
            unit = _edge_unit_label(self.system, conn)
            label = f"[{unit}]" if unit else ""
            color = COLOR_EDGE
            penwidth = "1.2"
            src_loop = loop_index.get(conn.src_comp)
            dst_loop = loop_index.get(conn.dst_comp)
            src_comp = self.system.components.get(conn.src_comp)
            dst_comp = self.system.components.get(conn.dst_comp)
            if (
                src_loop is not None
                and src_loop == dst_loop
                and src_comp is not None
                and dst_comp is not None
                and _is_direct_feedthrough_output(src_comp, conn.src_port)
                and _is_direct_feedthrough_input(dst_comp, conn.dst_port)
                and _is_zero_delay_connection(self.system, conn, active_outputs)
            ):
                color = COLOR_FEEDTHROUGH
                penwidth = "2.0"
            self.dot.edge(
                tail_name=conn.src_comp,
                head_name=conn.dst_comp,
                label=label,
                tailport=src_anchor,
                headport=dst_anchor,
                tailclip="true",
                headclip="true",
                style="solid",
                color=color,
                penwidth=penwidth,
            )

    def _add_event_edges(self) -> None:
        """Add event connection edges (dashed, blue, circle arrowhead)."""
        if self.dot is None:
            return
        for event_conn in self.system.event_connections:
            src_anchor = f"{event_conn.src_port}:e"
            dst_anchor = f"{event_conn.dst_port}:w"
            self.dot.edge(
                tail_name=event_conn.src_comp,
                head_name=event_conn.dst_comp,
                label="",
                tailport=src_anchor,
                headport=dst_anchor,
                tailclip="true",
                headclip="true",
                style="dashed",
                color=COLOR_EVENT,
                penwidth="2.0",
                arrowhead="odot",
                fontcolor=COLOR_EVENT,
            )