# Source code for synkit.Vis.its_drawer

from __future__ import annotations

"""ITS visualization.

The default ITS view is a single molecule-like transition graph.  Reactant /
product molecular projections remain available through ``projection=True`` for
debugging and comparison.
"""

from typing import Any, Optional, Tuple

import matplotlib.patheffects as pe
import matplotlib.pyplot as plt
import networkx as nx

from synkit.Graph.ITS.its_decompose import its_decompose
from synkit.Graph.ITS.its_reverter import ITSReverter
from synkit.IO.chem_converter import rsmi_to_its
from synkit.Vis.molecule_drawer import (
    _draw_aromatic_circles,
    _draw_bond_lines,
    _edge_is_aromatic,
    _element_colors,
    _element_label,
    _index_offset_vec,
    _layout_positions,
    _luminance,
    _set_padded_limits,
)
from synkit.Vis.reaction_drawer import draw_reaction_graphs, find_reaction_highlights
from synkit.Vis.visual_drawer import draw_graph


def draw_its_graph(
    its: nx.Graph,
    *,
    title: Optional[str] = None,
    mode: str = "sigma_pi",
    show_atom_map: bool = True,
    label_mode: str = "hetero",
    aromatic_style: str = "circle",
    include_delta_panel: bool = True,
    projection: bool = False,
    show_edge_labels: bool = False,
    edge_label_mode: str = "kekule",
    show_electron_labels: bool = False,
    electron_label_mode: str = "charge",
) -> tuple[plt.Figure, list[plt.Axes]]:
    """Draw an ITS graph.

    By default this draws only the ITS as a molecule-like graph (see
    :func:`draw_its_only`). Changed bonds are colored and compactly labeled
    from ``kekule_order``. Optional node electron labels can show one of
    charge, lone-pair, radical, or all changes.

    Set ``projection=True`` to draw reactant/product molecular projections
    plus a diagnostic ITS panel.

    :param its: ITS graph in tuple or legacy representation.
    :type its: nx.Graph
    :param title: Optional figure title.
    :type title: Optional[str]
    :param mode: Diagnostic label mode for the projection-mode delta panel.
    :type mode: str
    :param show_atom_map: Show atom-map labels.
    :type show_atom_map: bool
    :param label_mode: Atom label mode.
    :type label_mode: str
    :param aromatic_style: Aromatic style for molecular panels.
    :type aromatic_style: str
    :param include_delta_panel: In projection mode, include a diagnostic ITS
        graph panel.
    :type include_delta_panel: bool
    :param projection: If ``True``, draw reactant/product molecular projection
        panels plus an ITS delta panel. If ``False``, draw only the ITS graph.
    :type projection: bool
    :param show_edge_labels: If ``True``, show labels for unchanged edges too.
        Changed edge labels are shown by default unless ``edge_label_mode``
        is ``"none"``.
    :type show_edge_labels: bool
    :param edge_label_mode: ``"kekule"``, ``"sigma_pi"``, or ``"none"``.
    :type edge_label_mode: str
    :param show_electron_labels: Show changed atom electron annotations.
    :type show_electron_labels: bool
    :param electron_label_mode: ``"charge"``, ``"lone_pair"``, ``"radical"``,
        or ``"all"``.
    :type electron_label_mode: str
    :returns: ``(fig, axes)``. With ``projection=True`` and
        ``include_delta_panel=False`` the result of ``draw_reaction_graphs``
        is returned unchanged.
    :rtype: tuple[plt.Figure, list[plt.Axes]]
    :raises ValueError: From :func:`draw_its_only` if ``edge_label_mode`` or
        ``electron_label_mode`` is invalid (only when ``projection=False``).
    """
    if not projection:
        fig, ax = plt.subplots(figsize=(7.0, 5.0), facecolor="white")
        draw_its_only(
            its,
            ax=ax,
            title=title or "ITS",
            show_atom_map=show_atom_map,
            label_mode=label_mode,
            aromatic_style=aromatic_style,
            show_edge_labels=show_edge_labels,
            edge_label_mode=edge_label_mode,
            show_electron_labels=show_electron_labels,
            electron_label_mode=electron_label_mode,
        )
        fig.tight_layout()
        return fig, [ax]

    reactant, product = _its_to_side_graphs(its)
    if not include_delta_panel:
        return draw_reaction_graphs(
            reactant,
            product,
            title=title or "ITS projections",
            show_atom_map=show_atom_map,
            highlight_reaction_center=True,
            label_mode=label_mode,
            aromatic_style=aromatic_style,
        )

    # Panels are drawn here directly (with reaction_drawer's private helpers)
    # so the diagnostic ITS delta can share one figure with the projections.
    from synkit.Vis.reaction_drawer import _components, _draw_arrow, _draw_part

    # Size the grid from the very component lists that are drawn below, so
    # the number of axes can never drift from the number of panels drawn.
    reactant_parts = list(_components(reactant))
    product_parts = list(_components(product))
    n_reaction_axes = len(reactant_parts) + len(product_parts) + 1  # +1: arrow

    fig = plt.figure(
        figsize=(max(10.0, 3.1 * (n_reaction_axes + 1)), 3.7),
        facecolor="white",
    )
    # One extra, wider column at the end hosts the ITS delta panel.
    grid = fig.add_gridspec(
        1,
        n_reaction_axes + 1,
        width_ratios=[1.0] * n_reaction_axes + [1.35],
    )
    axes = [fig.add_subplot(grid[0, index]) for index in range(n_reaction_axes + 1)]

    highlights = find_reaction_highlights(reactant, product)
    panel_axes = iter(axes)
    for index, part in enumerate(reactant_parts):
        _draw_part(
            part,
            next(panel_axes),
            title="Reactant" if index == 0 else "+",
            highlights=highlights,
            side="reactant",
            show_atom_map=show_atom_map,
            label_mode=label_mode,
            aromatic_style=aromatic_style,
        )
    _draw_arrow(next(panel_axes))
    for index, part in enumerate(product_parts):
        _draw_part(
            part,
            next(panel_axes),
            title="Product" if index == 0 else "+",
            highlights=highlights,
            side="product",
            show_atom_map=show_atom_map,
            label_mode=label_mode,
            aromatic_style=aromatic_style,
        )

    draw_graph(
        its,
        ax=axes[-1],
        mode=mode,
        title="ITS delta",
        show_atom_map=show_atom_map,
        layout="kamada_kawai",
    )
    if title:
        fig.suptitle(title, fontsize=12, fontweight="bold", y=0.98)
    fig.tight_layout()
    return fig, axes
def draw_its_from_rsmi(
    rsmi: str,
    *,
    format: str = "tuple",
    core: bool = False,
    title: Optional[str] = None,
    mode: str = "sigma_pi",
    show_atom_map: bool = True,
    label_mode: str = "hetero",
    aromatic_style: str = "circle",
    include_delta_panel: bool = True,
    projection: bool = False,
    show_edge_labels: bool = False,
    edge_label_mode: str = "kekule",
    show_electron_labels: bool = False,
    electron_label_mode: str = "charge",
) -> tuple[plt.Figure, list[plt.Axes]]:
    """Convert a reaction SMILES into an ITS graph and draw it.

    ``rsmi``, ``core`` and ``format`` are passed to ``rsmi_to_its``. Every
    remaining option is forwarded unchanged to :func:`draw_its_graph`, except
    that an empty/missing ``title`` becomes ``"ITS from RSMI"``.

    :param rsmi: Reaction SMILES string.
    :param format: ITS representation requested from ``rsmi_to_its``.
    :param core: Forwarded to ``rsmi_to_its`` as ``core``.
    :returns: ``(fig, axes)`` as produced by :func:`draw_its_graph`.
    """
    its_graph = rsmi_to_its(rsmi, core=core, format=format)
    forwarded = {
        "mode": mode,
        "show_atom_map": show_atom_map,
        "label_mode": label_mode,
        "aromatic_style": aromatic_style,
        "include_delta_panel": include_delta_panel,
        "projection": projection,
        "show_edge_labels": show_edge_labels,
        "edge_label_mode": edge_label_mode,
        "show_electron_labels": show_electron_labels,
        "electron_label_mode": electron_label_mode,
    }
    return draw_its_graph(its_graph, title=title or "ITS from RSMI", **forwarded)
def draw_its_only(  # noqa: C901
    its: nx.Graph,
    *,
    ax: Optional[plt.Axes] = None,
    title: Optional[str] = None,
    show_atom_map: bool = True,
    label_mode: str = "hetero",
    aromatic_style: str = "circle",
    show_edge_labels: bool = False,
    edge_label_mode: str = "kekule",
    show_electron_labels: bool = False,
    electron_label_mode: str = "charge",
) -> plt.Axes:
    """Draw a molecule-like ITS transition graph on one axes.

    Bonds are colored by their ITS state (formed, broken, order changed,
    unchanged); formed and broken bonds additionally get a translucent halo
    and a dashed overlay. Atoms flagged as changed get an orange border.

    :param its: ITS graph in tuple or legacy representation.
    :param ax: Axes to draw on; it is cleared first. When ``None`` a new
        7x5 inch figure is created and tight-laid-out after drawing. A
        caller-supplied axes' figure layout is left to the caller.
    :param title: Optional axes title.
    :param show_atom_map: Draw atom-map numbers beside atoms; the node id is
        used when ``atom_map`` is missing, ``None`` or ``0``.
    :param label_mode: Atom label mode forwarded to ``_element_label``.
    :param aromatic_style: Forwarded to ``_draw_bond_lines``; ``"circle"``
        also draws aromatic ring circles.
    :param show_edge_labels: Also label unchanged bonds.
    :param edge_label_mode: ``"kekule"``, ``"sigma_pi"`` or ``"none"``
        (case-insensitive).
    :param show_electron_labels: Draw electron-change annotations below
        atoms that have one.
    :param electron_label_mode: ``"charge"``, ``"lone_pair"``, ``"radical"``
        or ``"all"`` (case-insensitive).
    :returns: The axes that was drawn on.
    :raises ValueError: If ``edge_label_mode`` or ``electron_label_mode`` is
        not a supported value.
    """
    edge_label_mode = edge_label_mode.lower()
    if edge_label_mode not in {"none", "kekule", "sigma_pi"}:
        raise ValueError("edge_label_mode must be one of: none, kekule, sigma_pi")
    electron_label_mode = electron_label_mode.lower()
    if electron_label_mode not in {"charge", "lone_pair", "radical", "all"}:
        raise ValueError(
            "electron_label_mode must be one of: charge, lone_pair, radical, all"
        )

    display = _its_display_graph(its)

    # Only a figure created here is re-laid-out at the end: a caller-supplied
    # axes may belong to a figure managed by a different layout engine.
    owns_figure = ax is None
    if owns_figure:
        _, ax = plt.subplots(figsize=(7.0, 5.0), facecolor="white")
    ax.clear()
    ax.set_facecolor("white")
    ax.set_axis_off()
    ax.set_aspect("equal")

    nodes = list(display.nodes())
    pos = _layout_positions(display, nodes, use_h_count=False)
    avg_len = _avg_edge_length(pos, display)
    bond_offset = avg_len * 0.09
    atom_map_offset = avg_len * 0.18

    # Shrink glyphs as the graph grows so large ITS graphs stay legible.
    n_nodes = max(1, len(nodes))
    node_size = max(210, min(560, 5200 // n_nodes))
    bond_width = max(1.5, min(2.8, 26 / n_nodes))
    element_font_size = max(7, min(12, 100 // n_nodes))
    atom_map_font_size = max(7, element_font_size)

    # --- bonds -----------------------------------------------------------
    for u, v, attrs in display.edges(data=True):
        p1, p2 = pos[u], pos[v]
        xs, ys = [p1[0], p2[0]], [p1[1], p2[1]]
        state = attrs.get("its_state", "unchanged")
        order = attrs.get("display_order", 1.0)
        aromatic = bool(attrs.get("display_aromatic", False))
        color = _state_color(state)
        formed_or_broken = state in {"formed", "broken"}

        if formed_or_broken:
            # Translucent halo underneath the bond lines.
            ax.plot(
                xs,
                ys,
                color=color,
                linewidth=bond_width * 3.6,
                alpha=0.18,
                solid_capstyle="round",
                zorder=1,
            )
        _draw_bond_lines(
            ax,
            p1,
            p2,
            order=max(1, int(round(order))),
            aromatic=aromatic,
            aromatic_style=aromatic_style,
            offset=bond_offset,
            lw=bond_width if state == "unchanged" else bond_width * 1.25,
            color=color,
        )
        if formed_or_broken:
            # Dashed overlay on top, distinguishing formed/broken bonds from
            # mere order changes.
            ax.plot(
                xs,
                ys,
                color=color,
                linewidth=bond_width * 1.6,
                linestyle=(0, (3, 3)),
                alpha=0.95,
                solid_capstyle="round",
                zorder=3,
            )

        edge_label = attrs.get(f"its_label_{edge_label_mode}", "")
        if (
            edge_label_mode != "none"
            and (show_edge_labels or state != "unchanged")
            and edge_label
        ):
            ax.text(
                (p1[0] + p2[0]) / 2,
                (p1[1] + p2[1]) / 2,
                edge_label,
                fontsize=7,
                ha="center",
                va="center",
                color="#111827",
                bbox={
                    "boxstyle": "round,pad=0.12",
                    "fc": "white",
                    "ec": "none",
                    "alpha": 0.9,
                },
                zorder=9,
            )

    if aromatic_style == "circle":
        _draw_aromatic_circles(ax, display, pos, scale=0.52)

    # --- atoms -----------------------------------------------------------
    changed_width = max(2.2, node_size**0.5 * 0.1)
    plain_width = max(1.0, node_size**0.5 * 0.065)
    node_colors, node_borders, node_linewidths = [], [], []
    for node in nodes:
        node_attrs = display.nodes[node]
        fill, border = _element_colors(str(node_attrs.get("element", "C")))
        changed = node_attrs.get("its_changed", False)
        node_colors.append(fill)
        node_borders.append("#f97316" if changed else border)
        node_linewidths.append(changed_width if changed else plain_width)

    node_artist = nx.draw_networkx_nodes(
        display,
        pos,
        nodelist=nodes,
        node_color=node_colors,
        edgecolors=node_borders,
        linewidths=node_linewidths,
        node_size=node_size,
        ax=ax,
    )
    # Older networkx releases return None for an empty node list.
    if node_artist is not None:
        node_artist.set_zorder(4)

    # --- atom annotations -------------------------------------------------
    for node in nodes:
        attrs = display.nodes[node]
        x, y = pos[node]

        text = _element_label(attrs, label_mode=label_mode)
        if text:
            fill, _ = _element_colors(str(attrs.get("element", "C")))
            ax.text(
                x,
                y,
                text,
                ha="center",
                va="center",
                fontsize=element_font_size,
                fontweight="bold",
                color="white" if _luminance(fill) < 0.5 else "#1f2937",
                zorder=10,
            )

        if show_atom_map:
            atom_map = attrs.get("atom_map", node)
            if atom_map in (None, 0):
                atom_map = node
            dx, dy = _index_offset_vec(node, display, pos, base=atom_map_offset)
            ax.text(
                x + dx,
                y + dy,
                str(atom_map),
                ha="center",
                va="center",
                fontsize=atom_map_font_size,
                fontweight="bold",
                color="#111827",
                path_effects=[pe.withStroke(linewidth=2.5, foreground="white")],
                zorder=11,
            )

        if show_electron_labels:
            electron_label = attrs.get(f"its_electron_label_{electron_label_mode}", "")
            if electron_label:
                # Placed below the atom, further out than the atom-map label.
                _, dy = _index_offset_vec(
                    node, display, pos, base=atom_map_offset * 2.35
                )
                ax.text(
                    x,
                    y - abs(dy),
                    electron_label,
                    ha="center",
                    va="center",
                    fontsize=max(7, element_font_size - 1),
                    color="#374151",
                    bbox={
                        "boxstyle": "round,pad=0.16",
                        "fc": "white",
                        "ec": "#cbd5e1",
                        "alpha": 0.92,
                    },
                    zorder=12,
                )

    if title:
        ax.set_title(title, fontsize=12, fontweight="bold", pad=8)
    _set_padded_limits(ax, pos, avg_len)
    if owns_figure:
        ax.figure.tight_layout()
    return ax
def _its_to_side_graphs(its: nx.Graph) -> Tuple[nx.Graph, nx.Graph]:
    """Split an ITS graph into its reactant-side and product-side graphs.

    Graphs carrying plain ``(reactant, product)`` tuple attributes are
    reverted with ``ITSReverter`` (recomputing neighbours on both sides);
    anything else is delegated to ``its_decompose``.

    :param its: ITS graph in tuple or legacy representation.
    :returns: ``(reactant, product)``.
    """
    if _has_direct_tuple_attrs(its):
        reverter = ITSReverter(its)
        return (
            reverter.to_reactant_graph(recompute_neighbors=True),
            reverter.to_product_graph(recompute_neighbors=True),
        )
    return its_decompose(its)


def _its_display_graph(its: nx.Graph) -> nx.Graph:
    """Build the annotated union graph drawn by ``draw_its_only``.

    The result is ``nx.compose(reactant, product)`` (product attributes win
    on conflicts). Nodes gain ``its_changed`` plus one
    ``its_electron_label_<mode>`` entry per electron label. Edges gain
    ``its_state``, ``display_order`` (also copied to ``order``),
    ``display_aromatic``, ``its_label_kekule`` and ``its_label_sigma_pi``.
    """
    reactant, product = _its_to_side_graphs(its)
    display = nx.compose(reactant, product)

    for node in display.nodes:
        r_attrs = reactant.nodes[node] if node in reactant else {}
        p_attrs = product.nodes[node] if node in product else {}
        # An atom is "changed" when any tracked property differs between the
        # two sides (a side lacking the atom contributes ``None``).
        display.nodes[node]["its_changed"] = any(
            r_attrs.get(key) != p_attrs.get(key)
            for key in ("element", "charge", "hcount", "radical", "lone_pairs")
        )
        for mode, label in _electron_node_labels(r_attrs, p_attrs).items():
            display.nodes[node][f"its_electron_label_{mode}"] = label

    for u, v in display.edges():
        r_data = reactant.get_edge_data(u, v)
        p_data = product.get_edge_data(u, v)
        r_order = _edge_order_value(r_data)
        p_order = _edge_order_value(p_data)
        state = _edge_state(r_order, p_order)
        display_order = max(r_order, p_order, 1.0)
        display.edges[u, v].update(
            its_state=state,
            display_order=display_order,
            display_aromatic=_is_display_aromatic(r_data, p_data),
            order=display_order,
            # Explicit separator: "1.52" would be ambiguous (1.5->2 vs 1->52).
            its_label_kekule=f"{_fmt_order(r_order)}→{_fmt_order(p_order)}",
            its_label_sigma_pi=_sigma_pi_label(r_data, p_data),
        )
        if state != "unchanged":
            display.nodes[u]["its_changed"] = True
            display.nodes[v]["its_changed"] = True
    return display


def _edge_order_value(attrs: Optional[dict[str, Any]]) -> float:
    """Return ``kekule_order`` (else ``order``, else 1.0) as a float.

    A missing edge (``None``/empty attrs) or an unparsable value yields 0.0.
    """
    if not attrs:
        return 0.0
    value = attrs.get("kekule_order", attrs.get("order", 1.0))
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0


def _edge_state(before: float, after: float) -> str:
    """Classify a bond as unchanged, formed, broken or order_changed."""
    if abs(before - after) < 1e-9:
        return "unchanged"
    if before == 0 and after > 0:
        return "formed"
    if before > 0 and after == 0:
        return "broken"
    return "order_changed"


def _is_display_aromatic(
    reactant_attrs: Optional[dict[str, Any]],
    product_attrs: Optional[dict[str, Any]],
) -> bool:
    """Return True if the bond is aromatic on either side of the reaction."""
    return any(
        attrs is not None and _edge_is_aromatic(attrs)
        for attrs in (reactant_attrs, product_attrs)
    )


def _state_color(state: str) -> str:
    """Map an ITS bond state to its hex drawing color (grey by default)."""
    return {
        "formed": "#15803d",
        "broken": "#b91c1c",
        "order_changed": "#ca8a04",
        "unchanged": "#374151",
    }.get(state, "#374151")


def _fmt_order(order: float) -> str:
    """Format a bond order compactly; a zero order (no bond) becomes ``∅``."""
    if order == 0:
        return "∅"
    if float(order).is_integer():
        return str(int(order))
    return f"{order:g}"


def _sigma_pi_label(
    reactant_attrs: Optional[dict[str, Any]],
    product_attrs: Optional[dict[str, Any]],
) -> str:
    """Describe sigma/pi order changes, e.g. ``"π1→∅"``; ``""`` if none."""
    r_sigma = _specific_order_value(reactant_attrs, "sigma_order")
    p_sigma = _specific_order_value(product_attrs, "sigma_order")
    r_pi = _specific_order_value(reactant_attrs, "pi_order")
    p_pi = _specific_order_value(product_attrs, "pi_order")
    parts = []
    if abs(r_sigma - p_sigma) > 1e-9:
        parts.append(f"σ{_fmt_order(r_sigma)}→{_fmt_order(p_sigma)}")
    if abs(r_pi - p_pi) > 1e-9:
        parts.append(f"π{_fmt_order(r_pi)}→{_fmt_order(p_pi)}")
    return " ".join(parts)


def _specific_order_value(attrs: Optional[dict[str, Any]], key: str) -> float:
    """Return ``attrs[key]`` as a float; missing/None/unparsable gives 0.0."""
    if not attrs:
        return 0.0
    value = attrs.get(key, 0.0)
    if value is None:
        return 0.0
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0


def _electron_node_labels(
    reactant_attrs: dict[str, Any],
    product_attrs: dict[str, Any],
) -> dict[str, str]:
    """Build per-mode electron-change labels for one atom.

    Keys ``"charge"``, ``"lone_pair"`` and ``"radical"`` appear only when
    that property differs (missing values count as 0). ``"all"`` is always
    present and joins the individual labels with spaces (``""`` if none).
    """
    labels: dict[str, str] = {}
    for key, mode, prefix in (
        ("charge", "charge", "q"),
        ("lone_pairs", "lone_pair", "λ"),
        ("radical", "radical", "rad"),
    ):
        before = reactant_attrs.get(key, 0)
        after = product_attrs.get(key, 0)
        if before == after:
            continue
        formatter = _fmt_signed if key == "charge" else _fmt_count
        # Explicit separator keeps multi-digit values unambiguous ("λ2→3").
        labels[mode] = f"{prefix}{formatter(before)}→{formatter(after)}"
    labels["all"] = " ".join(labels.values())
    return labels


def _fmt_signed(value: Any) -> str:
    """Format an integer with an explicit sign (``+1``, ``-1``, ``0``)."""
    try:
        number = int(value)
    except (TypeError, ValueError):
        return str(value)
    if number > 0:
        return f"+{number}"
    if number < 0:
        return str(number)
    return "0"


def _fmt_count(value: Any) -> str:
    """Format a count as an integer string, falling back to ``str(value)``."""
    try:
        number = int(value)
    except (TypeError, ValueError):
        return str(value)
    return str(number)


def _avg_edge_length(pos: dict[Any, tuple[float, float]], graph: nx.Graph) -> float:
    """Return the mean drawn edge length, used as the drawing's length scale.

    Falls back to ``1.0`` when there are no edges or every edge has zero
    length, so bond/label offsets and the padding passed to
    ``_set_padded_limits`` never collapse to zero.
    """
    if graph.number_of_edges() == 0:
        return 1.0
    lengths = [
        ((pos[v][0] - pos[u][0]) ** 2 + (pos[v][1] - pos[u][1]) ** 2) ** 0.5
        for u, v in graph.edges()
    ]
    mean = sum(lengths) / len(lengths)
    return mean if mean > 0 else 1.0


def _has_direct_tuple_attrs(its: nx.Graph) -> bool:
    """Return True if any node/edge stores a plain ``(reactant, product)`` pair."""
    node_keys = ("element", "hcount", "charge", "radical", "lone_pairs", "present")
    edge_keys = ("kekule_order", "sigma_order", "pi_order")
    for _, attrs in its.nodes(data=True):
        if any(_is_plain_pair(attrs.get(key)) for key in node_keys):
            return True
    for _, _, attrs in its.edges(data=True):
        if any(_is_plain_pair(attrs.get(key)) for key in edge_keys):
            return True
    return False


def _is_plain_pair(value: object) -> bool:
    """Return True for a 2-tuple whose items are not themselves containers."""
    return (
        isinstance(value, tuple)
        and len(value) == 2
        and not any(isinstance(item, (tuple, list, set, dict)) for item in value)
    )