from __future__ import annotations
"""Matplotlib drawing helpers for representation-aware SynKit visuals."""
from typing import Any, Mapping
import matplotlib.pyplot as plt
import networkx as nx
from synkit.Vis.visual_model import VisualGraph, to_visual_graph
ELEMENT_COLORS = {
"H": "#ffffff",
"C": "#f8fafc",
"N": "#bfdbfe",
"O": "#fecaca",
"F": "#bbf7d0",
"Cl": "#bbf7d0",
"Br": "#fed7aa",
"I": "#ddd6fe",
"S": "#fde68a",
"P": "#fecdd3",
"B": "#e7e5e4",
"Si": "#e9d5ff",
}
[docs]
def draw_graph(
graph: nx.Graph | VisualGraph,
*,
ax: plt.Axes | None = None,
mode: str = "compact",
title: str | None = None,
show_atom_map: bool = True,
layout: str = "spring",
pos: Mapping[Any, tuple[float, float]] | None = None,
seed: int = 7,
node_size: int = 980,
font_size: int = 9,
edge_label_font_size: int = 8,
show_edge_labels: bool = True,
show_node_badges: bool = True,
) -> tuple[plt.Figure, plt.Axes]:
"""Draw a molecule, ITS, or MTG graph using the visual adapter.
:param graph: Raw NetworkX graph or already adapted ``VisualGraph``.
:type graph: Union[nx.Graph, VisualGraph]
:param ax: Optional Matplotlib axes.
:type ax: Optional[plt.Axes]
:param mode: Adapter label mode, e.g. ``compact``, ``electron``,
``sigma_pi``, or ``timeline``.
:type mode: str
:param title: Optional title. Defaults to the detected visual kind.
:type title: Optional[str]
:param show_atom_map: Include atom maps in labels when adapting raw graphs.
:type show_atom_map: bool
:param layout: Layout name: ``spring``, ``kamada_kawai``, ``circular``, or
``shell``.
:type layout: str
:param pos: Optional fixed positions.
:type pos: Optional[Mapping[Any, Tuple[float, float]]]
:returns: ``(figure, axes)``.
:rtype: Tuple[plt.Figure, plt.Axes]
"""
visual = (
graph
if isinstance(graph, VisualGraph)
else to_visual_graph(
graph,
mode=mode, # type: ignore[arg-type]
show_atom_map=show_atom_map,
title=title or "",
)
)
nx_graph = _to_nx_graph(visual)
if ax is None:
fig, ax = plt.subplots(figsize=_figure_size(nx_graph))
else:
fig = ax.figure
if pos is None:
pos = _layout(nx_graph, layout=layout, seed=seed)
ax.clear()
ax.set_axis_off()
ax.set_aspect("equal")
ax.set_title(title or visual.title or visual.kind, fontsize=12, fontweight="bold")
edges = list(nx_graph.edges(data=True))
nodes = list(nx_graph.nodes(data=True))
if edges:
nx.draw_networkx_edges(
nx_graph,
pos,
ax=ax,
edge_color=[data["visual_color"] for _, _, data in edges],
width=[data["visual_width"] for _, _, data in edges],
alpha=0.88,
)
node_collection = nx.draw_networkx_nodes(
nx_graph,
pos,
ax=ax,
node_color=[data["fill"] for _, data in nodes],
edgecolors=[data["border"] for _, data in nodes],
linewidths=[2.4 if data["changed"] else 1.2 for _, data in nodes],
node_size=node_size,
)
node_collection.set_zorder(3)
labels = {
node: _node_label(data, show_node_badges=show_node_badges)
for node, data in nx_graph.nodes(data=True)
}
nx.draw_networkx_labels(
nx_graph,
pos,
labels=labels,
ax=ax,
font_size=font_size,
font_color="#111827",
)
if show_edge_labels:
edge_labels = {
(u, v): data["label"]
for u, v, data in nx_graph.edges(data=True)
if data.get("label")
}
if edge_labels:
nx.draw_networkx_edge_labels(
nx_graph,
pos,
edge_labels=edge_labels,
ax=ax,
font_size=edge_label_font_size,
font_color="#111827",
bbox={
"boxstyle": "round,pad=0.18",
"fc": "white",
"ec": "#d1d5db",
"alpha": 0.92,
},
)
_pad_limits(ax, pos)
return fig, ax
def _to_nx_graph(visual: VisualGraph) -> nx.Graph:
graph = nx.Graph()
for node in visual.nodes:
graph.add_node(
node.node_id,
label=node.label,
badges=node.badges,
changed=node.changed,
fill=ELEMENT_COLORS.get(node.element or "", "#f3f4f6"),
border="#dc2626" if node.changed else "#374151",
)
for edge in visual.edges:
graph.add_edge(
edge.source,
edge.target,
label=edge.label,
state=edge.state,
visual_color=edge.color,
visual_width=edge.width,
)
return graph
def _node_label(data: Mapping[str, Any], *, show_node_badges: bool) -> str:
label = str(data.get("label", ""))
badges = data.get("badges") or ()
if show_node_badges and badges:
return f"{label}\n{' '.join(badges)}"
return label
def _layout(graph: nx.Graph, *, layout: str, seed: int) -> dict[Any, Any]:
if graph.number_of_nodes() == 0:
return {}
if layout == "spring":
return nx.spring_layout(graph, seed=seed, k=1.1)
if layout == "kamada_kawai":
return nx.kamada_kawai_layout(graph)
if layout == "circular":
return nx.circular_layout(graph)
if layout == "shell":
return nx.shell_layout(graph)
raise ValueError("layout must be one of: spring, kamada_kawai, circular, shell")
def _figure_size(graph: nx.Graph) -> tuple[float, float]:
n_nodes = max(1, graph.number_of_nodes())
width = min(12.0, max(4.8, 1.25 * n_nodes))
height = min(8.0, max(3.6, 0.85 * n_nodes))
return width, height
def _pad_limits(ax: plt.Axes, pos: Mapping[Any, Any]) -> None:
if not pos:
return
xs = [point[0] for point in pos.values()]
ys = [point[1] for point in pos.values()]
x_span = max(xs) - min(xs)
y_span = max(ys) - min(ys)
pad = max(x_span, y_span, 1.0) * 0.18
ax.set_xlim(min(xs) - pad, max(xs) + pad)
ax.set_ylim(min(ys) - pad, max(ys) + pad)