from __future__ import annotations
"""MTG visualization helpers.
The compact MTG view is a timeline diagnostic. Step panels reuse the molecule-
like ITS renderer so each reconstructed ITS step is inspected with the same
visual language as normal Lewis State Graph / ITS drawings.
"""
from typing import Any, Iterable, Mapping, Optional
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import networkx as nx
from mpl_toolkits.mplot3d import Axes3D # noqa: F401
from synkit.Vis.its_drawer import draw_its_only
# Node fill colour keyed by element symbol. Symbols missing from this table
# are given a fallback fill by the lookup in ``_mtg_display_graph``.
ELEMENT_COLORS = {
    "H": "#ffffff",
    "C": "#f8fafc",
    "N": "#bfdbfe",
    "O": "#fecaca",
    "F": "#bbf7d0",
    "Cl": "#bbf7d0",
    "Br": "#fed7aa",
    "I": "#ddd6fe",
    "S": "#fde68a",
    "P": "#fecdd3",
    "B": "#e7e5e4",
    "Si": "#e9d5ff",
}
# Edge appearance keyed by edge state; each value is unpacked by the drawing
# code as ``(colour, linestyle, linewidth)``.
EDGE_STYLES = {
    "unchanged": ("#94a3b8", "solid", 1.7),
    "formed": ("#15803d", "solid", 3.1),
    "broken": ("#b91c1c", "solid", 3.1),
    "transient": ("#ec4899", "dashed", 3.0),
}
[docs]
def draw_mtg_graph(
    mtg: Any,
    *,
    ax: Optional[plt.Axes] = None,
    title: Optional[str] = None,
    mode: str = "timeline",
    layout: str = "kamada_kawai",
    show_atom_map: bool = True,
    show_edge_labels: bool = True,
    show_node_badges: bool = False,
    hydrogen_mode: str = "changed",
    changed_only: bool = False,
    compress: bool = True,
    show_step_axis: bool = False,
    dimension: str = "2d",
    seed: int = 7,
) -> tuple[plt.Figure, plt.Axes]:
    """Render the compact MTG timeline graph on a Matplotlib axes.

    The input is first normalised to a compact MTG ``networkx.Graph``, then
    reduced to a display graph (filtered nodes/edges plus visual attributes),
    and finally drawn in 2D or 3D.

    :param mtg: MTG object exposing ``get_mtg()`` or a compact MTG graph.
    :type mtg: Any
    :param ax: Axes to draw on; a new figure is created when omitted.
    :type ax: Optional[plt.Axes]
    :param title: Plot title; falls back to ``"MTG timeline"`` when empty.
    :type title: Optional[str]
    :param mode: Label mode. ``"timeline"`` is the default view;
        ``"sigma_pi"`` labels sigma/pi timelines separately when they change.
    :type mode: str
    :param layout: Layout name: ``"kamada_kawai"``, ``"spring"``,
        ``"circular"``, or ``"shell"``.
    :type layout: str
    :param show_atom_map: If True, node labels carry ``element:atom_map``
        (or ``element:node`` when no atom map is available).
    :type show_atom_map: bool
    :param show_edge_labels: If True, draw timeline labels on changed edges.
    :type show_edge_labels: bool
    :param show_node_badges: If True, append badges for changing node
        attributes under the node label.
    :type show_node_badges: bool
    :param hydrogen_mode: ``"changed"`` keeps only hydrogens incident to
        changing edges, ``"all"`` keeps all, ``"none"`` hides all hydrogens.
    :type hydrogen_mode: str
    :param changed_only: If True, hide unchanged edges and nodes that are not
        incident to a changing edge.
    :type changed_only: bool
    :param compress: If True, edge labels show only first and final state;
        otherwise the full timeline.
    :type compress: bool
    :param show_step_axis: Draw a compact state axis under the 2D graph.
    :type show_step_axis: bool
    :param dimension: ``"2d"`` or ``"3d"``.
    :type dimension: str
    :param seed: Random seed forwarded to the spring layout.
    :type seed: int
    :returns: ``(figure, axes)``.
    :rtype: tuple[plt.Figure, plt.Axes]
    :raises ValueError: If ``dimension`` is neither ``"2d"`` nor ``"3d"``.
    """
    supported_dimensions = {"2d", "3d"}
    if dimension not in supported_dimensions:
        raise ValueError("dimension must be '2d' or '3d'")

    # Stage 1: normalise the input and derive the filtered display graph.
    display_options = {
        "mode": mode,
        "show_atom_map": show_atom_map,
        "show_node_badges": show_node_badges,
        "hydrogen_mode": hydrogen_mode,
        "changed_only": changed_only,
        "compress": compress,
    }
    display = _mtg_display_graph(_as_mtg_graph(mtg), **display_options)

    # Stage 2: hand the display graph to the renderer.
    heading = title or "MTG timeline"
    return _draw_mtg_display(
        display,
        ax=ax,
        title=heading,
        layout=layout,
        show_edge_labels=show_edge_labels,
        show_step_axis=show_step_axis,
        dimension=dimension,
        seed=seed,
    )
[docs]
def draw_mtg_steps(
    mtg: Any,
    *,
    steps: Optional[Iterable[int]] = None,
    include_composed: bool = False,
    title: Optional[str] = None,
    max_columns: int = 3,
    show_atom_map: bool = True,
    label_mode: str = "hetero",
    edge_label_mode: str = "kekule",
    show_edge_labels: bool = False,
    show_electron_labels: bool = False,
    electron_label_mode: str = "charge",
) -> tuple[plt.Figure, list[plt.Axes]]:
    """Draw the reconstructed ITS steps of an MTG as a grid of panels.

    Each selected step is rendered with ``draw_its_only`` in its own subplot,
    titled ``"Step <n>"`` with one-based numbering.

    :param mtg: MTG object exposing ``get_its_steps``.
    :type mtg: Any
    :param steps: Zero-based step indices to draw; all steps when omitted.
    :type steps: Optional[Iterable[int]]
    :param include_composed: Append a ``"Composed"`` panel built from
        ``mtg.get_compose_its()``.
    :type include_composed: bool
    :param title: Optional figure-level title.
    :type title: Optional[str]
    :param max_columns: Maximum subplot columns (values below 1 act as 1).
    :type max_columns: int
    :param show_atom_map: Forwarded to ``draw_its_only``.
    :type show_atom_map: bool
    :param label_mode: Forwarded to ``draw_its_only``.
    :type label_mode: str
    :param edge_label_mode: Forwarded to ``draw_its_only``.
    :type edge_label_mode: str
    :param show_edge_labels: Forwarded to ``draw_its_only``.
    :type show_edge_labels: bool
    :param show_electron_labels: Forwarded to ``draw_its_only``.
    :type show_electron_labels: bool
    :param electron_label_mode: Forwarded to ``draw_its_only``.
    :type electron_label_mode: str
    :returns: ``(figure, axes)`` where ``axes`` holds only the used panels.
    :rtype: tuple[plt.Figure, list[plt.Axes]]
    :raises TypeError: If ``mtg`` lacks ``get_its_steps`` (or
        ``get_compose_its`` when ``include_composed`` is set).
    :raises IndexError: If a requested step index is out of range.
    :raises ValueError: If no panel remains to draw.
    """
    if not hasattr(mtg, "get_its_steps"):
        raise TypeError("draw_mtg_steps expects an MTG object with get_its_steps().")

    its_steps = list(mtg.get_its_steps())
    n_steps = len(its_steps)
    chosen = list(steps) if steps is not None else list(range(n_steps))

    # Validate every index up front so nothing is drawn for a bad request.
    for index in chosen:
        if index < 0 or index >= n_steps:
            raise IndexError(f"MTG step index out of range: {index}")

    panels = []
    for index in chosen:
        panels.append((f"Step {index + 1}", its_steps[index]))

    if include_composed:
        if not hasattr(mtg, "get_compose_its"):
            raise TypeError(
                "include_composed requires an MTG object with get_compose_its()."
            )
        panels.append(("Composed", mtg.get_compose_its()))

    n_panels = len(panels)
    if n_panels == 0:
        raise ValueError("No MTG steps selected for drawing.")

    # Grid shape: as many columns as allowed, rows rounded up.
    ncols = min(max(1, max_columns), n_panels)
    nrows, leftover = divmod(n_panels, ncols)
    if leftover:
        nrows += 1

    fig, axes_grid = plt.subplots(
        nrows,
        ncols,
        figsize=(4.8 * ncols, 4.2 * nrows),
        squeeze=False,
        facecolor="white",
    )
    axes = []
    for grid_row in axes_grid:
        axes.extend(grid_row)

    if title:
        fig.suptitle(title, fontsize=13, fontweight="bold")

    its_options = {
        "show_atom_map": show_atom_map,
        "label_mode": label_mode,
        "edge_label_mode": edge_label_mode,
        "show_edge_labels": show_edge_labels,
        "show_electron_labels": show_electron_labels,
        "electron_label_mode": electron_label_mode,
    }
    for panel_ax, (panel_title, its) in zip(axes, panels):
        draw_its_only(its, ax=panel_ax, title=panel_title, **its_options)

    # Blank out grid cells that did not receive a panel.
    for spare_ax in axes[n_panels:]:
        spare_ax.set_axis_off()

    fig.tight_layout()
    return fig, axes[:n_panels]
def _as_mtg_graph(mtg: Any) -> nx.Graph:
    """Return the compact MTG graph behind ``mtg``.

    Accepts either a ``networkx.Graph`` (returned as-is) or an object whose
    ``get_mtg()`` yields one.

    :raises TypeError: If no ``networkx.Graph`` can be obtained.
    """
    candidate = mtg
    if not isinstance(candidate, nx.Graph) and hasattr(candidate, "get_mtg"):
        candidate = candidate.get_mtg()
    if isinstance(candidate, nx.Graph):
        return candidate
    raise TypeError("Expected an MTG object or a NetworkX compact MTG graph.")
def _mtg_display_graph(
graph: nx.Graph,
*,
mode: str,
show_atom_map: bool,
show_node_badges: bool,
hydrogen_mode: str,
changed_only: bool,
compress: bool,
) -> nx.Graph:
if hydrogen_mode not in {"changed", "all", "none"}:
raise ValueError("hydrogen_mode must be one of: changed, all, none")
edge_info = {
_edge_key(u, v): _edge_visual(attrs, mode=mode, compress=compress)
for u, v, attrs in graph.edges(data=True)
}
changed_incident = {
node
for (u, v), info in edge_info.items()
if info["state"] != "unchanged"
for node in (u, v)
}
display = nx.Graph()
for node, attrs in graph.nodes(data=True):
element = str(_first_present(attrs.get("element")) or "")
atom_map = _first_present(attrs.get("atom_map"))
if element == "H":
if hydrogen_mode == "none":
continue
if hydrogen_mode == "changed" and atom_map in (None, 0):
continue
if hydrogen_mode == "changed" and node not in changed_incident:
continue
if changed_only and node not in changed_incident:
continue
label = _node_label(node, attrs, show_atom_map=show_atom_map)
badges = _node_badges(attrs) if show_node_badges else []
display.add_node(
node,
label=label,
badges=tuple(badges),
element=element,
changed=bool(badges) or node in changed_incident,
fill=ELEMENT_COLORS.get(element, "#f3f4f6"),
)
for u, v, attrs in graph.edges(data=True):
key = _edge_key(u, v)
info = edge_info[key]
if changed_only and info["state"] == "unchanged":
continue
if u not in display or v not in display:
continue
display.add_edge(u, v, **info, raw=dict(attrs))
display.graph["steps"] = _infer_state_count(graph)
return display
def _draw_mtg_display(
    graph: nx.Graph,
    *,
    ax: Optional[plt.Axes],
    title: str,
    layout: str,
    show_edge_labels: bool,
    show_step_axis: bool,
    dimension: str,
    seed: int,
) -> tuple[plt.Figure, plt.Axes]:
    """Render a prepared display graph and return ``(figure, axes)``.

    When ``ax`` is ``None`` a new figure is created (with a 3D projection for
    ``dimension == "3d"``). The axes are cleared before drawing. The 3D path
    delegates to ``_draw_mtg_display_3d`` and returns early; the 2D path draws
    edges per state, then nodes, labels, optional edge labels, the legend and
    the optional step axis.
    """
    is_3d = dimension == "3d"
    if ax is not None:
        fig = ax.figure
    else:
        fig = plt.figure(figsize=_figure_size(graph), facecolor="white")
        subplot_kwargs = {"projection": "3d"} if is_3d else {}
        ax = fig.add_subplot(111, **subplot_kwargs)

    pos = _layout(graph, layout=layout, dimension=dimension, seed=seed)
    ax.clear()
    ax.set_axis_off()
    if dimension == "2d":
        ax.set_aspect("equal")
    ax.set_title(title, fontsize=13, fontweight="bold", pad=12)

    if is_3d:
        _draw_mtg_display_3d(graph, pos, ax=ax, show_edge_labels=show_edge_labels)
        _draw_legend(ax)
        fig.tight_layout()
        return fig, ax

    # Bucket edges by state once, then draw in a fixed state order so that
    # changed edges are painted after the unchanged ones.
    state_order = ("unchanged", "formed", "broken", "transient")
    edges_by_state = {state: [] for state in state_order}
    for u, v, attrs in graph.edges(data=True):
        bucket = edges_by_state.get(attrs.get("state"))
        if bucket is not None:
            bucket.append((u, v))

    for state in state_order:
        edgelist = edges_by_state[state]
        if not edgelist:
            continue
        color, style, width = EDGE_STYLES[state]
        nx.draw_networkx_edges(
            graph,
            pos,
            ax=ax,
            edgelist=edgelist,
            edge_color=color,
            style=style,
            width=width,
            alpha=0.38 if state == "unchanged" else 0.88,
        )

    node_attrs = [attrs for _, attrs in graph.nodes(data=True)]
    if node_attrs:
        fills, outlines, outline_widths, sizes = [], [], [], []
        for attrs in node_attrs:
            highlighted = attrs.get("changed")
            fills.append(attrs["fill"])
            outlines.append("#f97316" if highlighted else "#475569")
            outline_widths.append(2.6 if highlighted else 1.2)
            sizes.append(500 if attrs.get("element") == "H" else 760)
        nx.draw_networkx_nodes(
            graph,
            pos,
            ax=ax,
            node_color=fills,
            edgecolors=outlines,
            linewidths=outline_widths,
            node_size=sizes,
        )

    node_text = {}
    for node, attrs in graph.nodes(data=True):
        node_text[node] = _stack_node_label(attrs)
    nx.draw_networkx_labels(
        graph,
        pos,
        labels=node_text,
        ax=ax,
        font_size=8,
        font_weight="bold",
        font_color="#111827",
    )

    if show_edge_labels:
        edge_text = {}
        for u, v, attrs in graph.edges(data=True):
            if attrs.get("label"):
                edge_text[(u, v)] = attrs["label"]
        if edge_text:
            label_box = {
                "boxstyle": "round,pad=0.18",
                "fc": "white",
                "ec": "#cbd5e1",
                "alpha": 0.94,
            }
            nx.draw_networkx_edge_labels(
                graph,
                pos,
                edge_labels=edge_text,
                ax=ax,
                font_size=7,
                rotate=False,
                font_color="#111827",
                bbox=label_box,
            )

    _draw_legend(ax)
    if show_step_axis:
        _draw_step_axis(ax, graph.graph.get("steps", 0))
    _pad_limits(ax, pos)
    fig.tight_layout()
    return fig, ax
def _draw_mtg_display_3d(
    graph: nx.Graph,
    pos: Mapping[Any, Any],
    *,
    ax: plt.Axes,
    show_edge_labels: bool,
) -> None:
    """Draw edges, nodes and labels of a display graph on a 3D axes.

    Edges are drawn state by state (unchanged first) as line segments, with an
    optional text label at the segment midpoint. Nodes are drawn as scatter
    markers with their label placed slightly above along the z axis.
    """
    state_order = ("unchanged", "formed", "broken", "transient")

    # Group edges by state in a single pass; per-state order is preserved.
    grouped = {state: [] for state in state_order}
    for u, v, attrs in graph.edges(data=True):
        members = grouped.get(attrs.get("state"))
        if members is not None:
            members.append((u, v, attrs))

    for state in state_order:
        color, style, width = EDGE_STYLES[state]
        alpha = 0.28 if state == "unchanged" else 0.88
        for u, v, attrs in grouped[state]:
            start, stop = pos[u], pos[v]
            xs, ys, zs = ([start[axis], stop[axis]] for axis in range(3))
            ax.plot(
                xs,
                ys,
                zs,
                color=color,
                linestyle=style,
                linewidth=width,
                alpha=alpha,
            )
            if not (show_edge_labels and attrs.get("label")):
                continue
            midpoint = tuple((start[axis] + stop[axis]) / 2 for axis in range(3))
            ax.text(
                *midpoint,
                attrs["label"],
                fontsize=7,
                color="#111827",
                ha="center",
                va="center",
            )

    label_box = {
        "boxstyle": "round,pad=0.08",
        "fc": "white",
        "ec": "none",
        "alpha": 0.78,
    }
    for node, attrs in graph.nodes(data=True):
        x, y, z = pos[node]
        outline = "#f97316" if attrs.get("changed") else "#475569"
        marker_size = 320 if attrs.get("element") == "H" else 430
        ax.scatter(
            [x],
            [y],
            [z],
            s=marker_size,
            c=[attrs["fill"]],
            edgecolors=[outline],
            linewidths=1.5,
            depthshade=True,
        )
        # Lift the label a little so it does not sit inside the marker.
        ax.text(
            x,
            y,
            z + 0.12,
            _stack_node_label(attrs),
            fontsize=8.5,
            fontweight="bold",
            color="#111827",
            ha="center",
            va="center",
            bbox=label_box,
        )
def _edge_visual(
    attrs: Mapping[str, Any],
    *,
    mode: str,
    compress: bool,
) -> dict[str, Any]:
    """Describe how one edge is drawn.

    :returns: Mapping with ``history`` (preferred timeline), ``state``,
        ``label`` and the ``color``/``style``/``width`` taken from
        ``EDGE_STYLES`` for that state.
    """
    history = _preferred_timeline(attrs, mode=mode)
    state = _timeline_state(history)
    visual = {
        "history": tuple(history),
        "state": state,
        "label": _timeline_label(
            attrs, history, mode=mode, state=state, compress=compress
        ),
    }
    visual.update(zip(("color", "style", "width"), EDGE_STYLES[state]))
    return visual
def _preferred_timeline(attrs: Mapping[str, Any], *, mode: str) -> tuple[Any, ...]:
    """Pick the bond-order timeline used to classify and label an edge.

    In ``"sigma_pi"`` mode, when the sigma or pi timeline changes, the two are
    summed position by position (a position stays ``None`` only when both are
    ``None``). Otherwise the first non-empty timeline among ``kekule_order``,
    ``order``, ``sigma_order`` and ``pi_order`` is returned, or ``()``.
    """
    if mode == "sigma_pi":
        sigma = _coerce_timeline(attrs.get("sigma_order"))
        pi = _coerce_timeline(attrs.get("pi_order"))
        if _changes(sigma) or _changes(pi):
            combined = []
            for s, p in zip(_pad(sigma, pi), _pad(pi, sigma)):
                if s is None and p is None:
                    combined.append(None)
                else:
                    combined.append(_none_order(s) + _none_order(p))
            return tuple(combined)

    fallback_keys = ("kekule_order", "order", "sigma_order", "pi_order")
    candidates = (_coerce_timeline(attrs.get(key)) for key in fallback_keys)
    return next((timeline for timeline in candidates if timeline), ())
def _timeline_label(
attrs: Mapping[str, Any],
preferred: tuple[Any, ...],
*,
mode: str,
state: str,
compress: bool,
) -> str:
if state == "unchanged":
return ""
timeline = _compressed_timeline(preferred) if compress else preferred
if mode == "sigma_pi":
parts = []
for key, prefix in (("sigma_order", "σ"), ("pi_order", "π")):
part_timeline = _coerce_timeline(attrs.get(key))
if part_timeline and _changes(_known_timeline(part_timeline)):
part_timeline = (
_compressed_timeline(part_timeline) if compress else part_timeline
)
parts.append(f"{prefix}:{_format_timeline(part_timeline)}")
if parts:
return " ".join(parts)
return _format_timeline(timeline)
def _coerce_timeline(value: Any) -> tuple[Any, ...]:
if not isinstance(value, tuple):
return ()
if value and all(_is_step_pair(item) for item in value):
history = []
for idx, pair in enumerate(value):
left, right = pair
if idx == 0:
history.append(_clean_order(left))
history.append(_clean_order(right))
return tuple(history)
if value and not any(isinstance(item, (tuple, list, dict, set)) for item in value):
return value
return ()
def _is_step_pair(value: Any) -> bool:
return isinstance(value, tuple) and len(value) == 2
def _clean_order(value: Any) -> Any:
if isinstance(value, set):
return None
return value
def _timeline_state(timeline: tuple[Any, ...]) -> str:
    """Classify a timeline as unchanged, formed, broken or transient.

    Leading/trailing ``None`` entries are ignored and interior ``None``
    counts as order 0. ``"formed"`` means the order goes from 0 to a positive
    value, ``"broken"`` the reverse; every other varying timeline is
    ``"transient"``.
    """
    levels = [_none_order(value) for value in _known_timeline(timeline)]
    if len(set(levels)) <= 1:
        return "unchanged"

    first, last = levels[0], levels[-1]
    if first == 0 and last > 0:
        return "formed"
    if first > 0 and last == 0:
        return "broken"
    # Covers both "returns to its start value" and "order changed between
    # two non-zero values".
    return "transient"
def _node_label(
    node: Any,
    attrs: Mapping[str, Any],
    *,
    show_atom_map: bool,
) -> str:
    """Return the text label for a node.

    The element symbol falls back to ``str(node)`` when absent. With
    ``show_atom_map`` the label is ``element:atom_map``, or ``element:node``
    when the atom map is missing or 0.
    """
    symbol = _first_present(attrs.get("element")) or str(node)
    atom_map = _first_present(attrs.get("atom_map"))
    if not show_atom_map:
        return str(symbol)
    suffix = node if atom_map in (None, 0) else atom_map
    return f"{symbol}:{suffix}"
def _node_badges(attrs: Mapping[str, Any]) -> list[str]:
    """Return at most two ``prefix:timeline`` badges for changing node attributes.

    Attributes are inspected in the order charge, hcount, lone_pairs, radical;
    only those whose timeline varies produce a badge.
    """
    prefixes = {
        "charge": "q",
        "hcount": "H",
        "lone_pairs": "lp",
        "radical": "rad",
    }
    badges = []
    for key, prefix in prefixes.items():
        timeline = _coerce_node_timeline(attrs.get(key))
        if not (timeline and _changes(timeline)):
            continue
        badges.append(f"{prefix}:{_format_timeline(timeline)}")
    return badges[:2]
def _coerce_node_timeline(value: Any) -> tuple[Any, ...]:
if (
isinstance(value, tuple)
and value
and all(_is_step_pair(item) for item in value)
):
return _coerce_timeline(value)
if isinstance(value, tuple) and len(value) >= 3:
return value
if isinstance(value, tuple) and len(value) == 2:
return value
return ()
def _stack_node_label(attrs: Mapping[str, Any]) -> str:
label = str(attrs.get("label", ""))
badges = attrs.get("badges") or ()
return f"{label}\n{' '.join(badges)}" if badges else label
def _format_timeline(timeline: tuple[Any, ...]) -> str:
return "→".join(_format_order(value) for value in timeline)
def _compressed_timeline(timeline: tuple[Any, ...]) -> tuple[Any, ...]:
if len(timeline) <= 2:
return timeline
return (timeline[0], timeline[-1])
def _trim_timeline(timeline: tuple[Any, ...]) -> tuple[Any, ...]:
if len(timeline) <= 2:
return timeline
start = 0
end = len(timeline)
while start + 1 < end and timeline[start] == timeline[start + 1]:
start += 1
while end - 2 >= start and timeline[end - 1] == timeline[end - 2]:
end -= 1
return timeline[start:end]
def _format_order(value: Any) -> str:
if value is None:
return "∅"
if isinstance(value, float) and value.is_integer():
return str(int(value))
return str(value)
def _none_order(value: Any) -> float:
return 0.0 if value is None else float(value)
def _changes(timeline: tuple[Any, ...]) -> bool:
return bool(timeline) and len(set(timeline)) > 1
def _known_timeline(timeline: tuple[Any, ...]) -> tuple[Any, ...]:
start = 0
end = len(timeline)
while start < end and timeline[start] is None:
start += 1
while end > start and timeline[end - 1] is None:
end -= 1
return timeline[start:end]
def _pad(first: tuple[Any, ...], second: tuple[Any, ...]) -> tuple[Any, ...]:
if len(first) >= len(second):
return first
return first + (None,) * (len(second) - len(first))
def _first_present(value: Any) -> Any:
if isinstance(value, tuple):
for item in value:
if isinstance(item, tuple):
for side in item:
if side not in (None, set()):
return side
elif item is not None:
return item
return None
return value
def _edge_key(u: Any, v: Any) -> tuple[Any, Any]:
return (u, v) if str(u) <= str(v) else (v, u)
def _infer_state_count(graph: nx.Graph) -> int:
max_len = 0
for _, _, attrs in graph.edges(data=True):
for key in ("kekule_order", "order", "sigma_order", "pi_order"):
max_len = max(max_len, len(_coerce_timeline(attrs.get(key))))
return max_len
def _layout(
graph: nx.Graph,
*,
layout: str,
dimension: str,
seed: int,
) -> dict[Any, Any]:
if graph.number_of_nodes() == 0:
return {}
if dimension == "3d":
if layout not in {"spring", "kamada_kawai"}:
raise ValueError("3D MTG layout supports: spring, kamada_kawai")
return nx.spring_layout(graph, seed=seed, k=1.15, iterations=160, dim=3)
if layout == "spring":
return nx.spring_layout(graph, seed=seed, k=1.15, iterations=120)
if layout == "kamada_kawai":
return nx.kamada_kawai_layout(graph)
if layout == "circular":
return nx.circular_layout(graph)
if layout == "shell":
return nx.shell_layout(graph)
raise ValueError("layout must be one of: spring, kamada_kawai, circular, shell")
def _figure_size(graph: nx.Graph) -> tuple[float, float]:
n_nodes = max(1, graph.number_of_nodes())
return min(14.0, max(7.0, n_nodes * 0.78)), min(10.0, max(5.2, n_nodes * 0.55))
def _draw_legend(ax: plt.Axes) -> None:
    """Add a frameless legend for the formed, broken and transient edge styles."""
    handles = []
    for state in ("formed", "broken", "transient"):
        color, style, width = EDGE_STYLES[state]
        handles.append(
            mlines.Line2D(
                [], [], color=color, linestyle=style, linewidth=width, label=state
            )
        )
    ax.legend(
        handles=handles,
        loc="upper right",
        bbox_to_anchor=(1.0, 1.0),
        frameon=False,
        fontsize=8,
        ncol=1,
    )
def _draw_step_axis(ax: plt.Axes, states: int) -> None:
if states <= 1:
return
text = "states " + " → ".join(f"S{i}" for i in range(states))
ax.text(
0.5,
-0.045,
text,
transform=ax.transAxes,
ha="center",
va="top",
fontsize=8,
color="#475569",
)
def _pad_limits(ax: plt.Axes, pos: Mapping[Any, Any]) -> None:
if not pos:
return
xs = [point[0] for point in pos.values()]
ys = [point[1] for point in pos.values()]
x_span = max(xs) - min(xs)
y_span = max(ys) - min(ys)
pad = max(x_span, y_span, 1.0) * 0.25
ax.set_xlim(min(xs) - pad, max(xs) + pad)
ax.set_ylim(min(ys) - pad, max(ys) + pad)