# Source code for synkit.Vis.molecule_drawer

"""Chemistry-oriented molecular graph drawing.

This module draws scalar molecular ``nx.Graph`` objects as molecule-like
figures.  It is adapted from the copied ``vis_synedu`` renderer, but uses
SynKit's own graph-to-mol conversion and avoids relying on broken copied
relative imports.
"""

# The module docstring must be the first statement to become ``__doc__``;
# a future import may legally follow it.
from __future__ import annotations

import math
from typing import Any, Dict, Mapping, Optional, Set, Tuple

import matplotlib.patches as mpatches
import matplotlib.patheffects as pe
import matplotlib.pyplot as plt
import networkx as nx
from rdkit import Chem
from rdkit.Chem import AllChem, Draw

from synkit.IO.graph_to_mol import GraphToMol

# Element symbol -> (marker fill colour, marker border colour).
# Symbols that are not listed fall back to DEFAULT_FILL / DEFAULT_BORDER.
ELEMENT_PALETTE: Dict[str, Tuple[str, str]] = dict(
    C=("#5f6368", "#3d4145"),
    H=("#f8fafc", "#94a3b8"),
    O=("#e8524a", "#b83830"),
    N=("#5b8dd9", "#3a65b0"),
    S=("#e8a838", "#b87909"),
    P=("#e878c8", "#b84898"),
    F=("#5bc8af", "#2a9178"),
    Cl=("#3dbe6c", "#1e8a46"),
    Br=("#a0522d", "#6b3118"),
    I=("#8c54c8", "#5e2fa0"),
    B=("#d6a77a", "#9a6a44"),
    Si=("#f0c8a0", "#b88860"),
)

# Neutral grey fill / border for elements missing from the palette.
DEFAULT_FILL, DEFAULT_BORDER = "#a0a0a0", "#606060"


def draw_molecule_graph(  # noqa: C901
    graph: nx.Graph,
    *,
    ax: Optional[plt.Axes] = None,
    title: Optional[str] = None,
    label_mode: str = "hetero",
    show_atom_map: bool = False,
    show_bond_order: bool = False,
    aromatic_style: str = "circle",
    include_rdkit_panel: bool = False,
    use_h_count: bool = False,
    node_size: Optional[int] = None,
    bond_width: Optional[float] = None,
    figsize: Tuple[float, float] = (6.0, 5.0),
    highlight_nodes: Optional[Set[Any]] = None,
    highlight_edges: Optional[Set[Tuple[Any, Any]]] = None,
    highlight_color: str = "#f97316",
    custom_node_colors: Optional[Mapping[Any, str]] = None,
) -> plt.Axes | tuple[plt.Figure, tuple[plt.Axes, plt.Axes]]:
    """Draw a scalar molecular graph using RDKit coordinates when possible.

    Atom positions come from an RDKit 2D depiction of the graph (built via
    ``GraphToMol``); if conversion or layout fails, a Kamada-Kawai layout is
    used instead.

    :param graph: Molecular NetworkX graph with scalar ``element`` and
        ``order`` attributes.  A copy is drawn; the input is not modified.
    :type graph: nx.Graph
    :param ax: Optional Matplotlib axes to draw into.  The axes are cleared
        first.  Ignored when ``include_rdkit_panel=True``, which always
        creates a new two-panel figure.
    :type ax: Optional[plt.Axes]
    :param title: Optional title for the graph panel.
    :type title: Optional[str]
    :param label_mode: ``"all"`` (label every atom), ``"hetero"`` (label
        non-carbon atoms plus charged or radical carbons), or ``"none"``.
        Case-insensitive.
    :type label_mode: str
    :param show_atom_map: Show atom-map numbers near atoms.  The node key is
        shown when ``atom_map`` is missing, ``None`` or ``0``.
    :type show_atom_map: bool
    :param show_bond_order: Show numeric bond order labels on non-aromatic
        bonds.
    :type show_bond_order: bool
    :param aromatic_style: ``"circle"`` (aromatic bonds as single lines plus
        ring circles) or ``"dashed"`` (aromatic bonds as dashed lines).
        Case-insensitive.
    :type aromatic_style: str
    :param include_rdkit_panel: Also show RDKit's own rendering side-by-side.
    :type include_rdkit_panel: bool
    :param use_h_count: Pass graph ``hcount`` to ``GraphToMol`` for layout.
    :type use_h_count: bool
    :param node_size: Atom marker size.  Defaults to a size that is scaled
        down for larger molecules.
    :type node_size: Optional[int]
    :param bond_width: Bond line width.  Defaults to a width that is scaled
        down for larger molecules.
    :type bond_width: Optional[float]
    :param figsize: Size of a newly created figure.  The width is doubled
        when ``include_rdkit_panel=True``.
    :type figsize: Tuple[float, float]
    :param highlight_nodes: Nodes to mark with a translucent halo.  Nodes not
        present in the graph are ignored.
    :type highlight_nodes: Optional[Set[Any]]
    :param highlight_edges: Edges to mark with a translucent underlay.  Edge
        orientation does not matter.
    :type highlight_edges: Optional[Set[Tuple[Any, Any]]]
    :param highlight_color: Matplotlib colour for node and edge highlights.
    :type highlight_color: str
    :param custom_node_colors: Per-node colours that override the element
        palette for both fill and border.  Any Matplotlib colour
        specification (hex or named) is accepted.
    :type custom_node_colors: Optional[Mapping[Any, str]]
    :returns: Axes, or ``(fig, (rdkit_ax, graph_ax))`` when
        ``include_rdkit_panel=True``.
    :rtype: Union[plt.Axes, Tuple[plt.Figure, Tuple[plt.Axes, plt.Axes]]]
    :raises ValueError: If ``label_mode`` or ``aromatic_style`` is not one of
        the accepted values.
    """
    label_mode = label_mode.lower()
    aromatic_style = aromatic_style.lower()
    if label_mode not in {"all", "hetero", "none"}:
        raise ValueError("label_mode must be one of: all, hetero, none")
    if aromatic_style not in {"circle", "dashed"}:
        raise ValueError("aromatic_style must be one of: circle, dashed")

    graph_view = graph.copy()
    nodes = list(graph_view.nodes())
    # Clamp to 1 so the size heuristics below never divide by zero.
    n_nodes = max(1, len(nodes))

    fig, ax_rdkit, ax_graph = _resolve_axes(
        ax, include_rdkit_panel=include_rdkit_panel, figsize=figsize
    )
    ax_graph.clear()
    ax_graph.set_facecolor("white")
    ax_graph.set_axis_off()
    ax_graph.set_aspect("equal")

    pos = _layout_positions(graph_view, nodes, use_h_count=use_h_count)
    # Offsets are expressed relative to the mean bond length, so they scale
    # with whichever coordinate system the layout produced.
    avg_len = _avg_edge_length(pos, graph_view)
    bond_offset = avg_len * 0.09
    atom_map_offset = avg_len * 0.18

    scaled_node_size = (
        node_size if node_size is not None else max(180, min(560, 4600 // n_nodes))
    )
    scaled_bond_width = (
        bond_width if bond_width is not None else max(1.3, min(2.6, 24 / n_nodes))
    )
    element_font_size = max(7, min(12, 100 // n_nodes))
    atom_map_font_size = max(7, element_font_size)

    _draw_highlights(
        ax_graph,
        graph_view,
        pos,
        highlight_nodes=highlight_nodes,
        highlight_edges=_normalize_edge_set(highlight_edges),
        node_size=scaled_node_size,
        bond_width=scaled_bond_width,
        color=highlight_color,
    )

    for u, v, attrs in graph_view.edges(data=True):
        p1, p2 = pos[u], pos[v]
        aromatic = _edge_is_aromatic(attrs)
        order = _edge_order(attrs, aromatic=aromatic)
        _draw_bond_lines(
            ax_graph,
            p1,
            p2,
            order=order,
            aromatic=aromatic,
            aromatic_style=aromatic_style,
            offset=bond_offset,
            lw=scaled_bond_width,
            color="#262a2f",
        )
        if show_bond_order and not aromatic:
            _draw_bond_order_label(ax_graph, p1, p2, order)

    if aromatic_style == "circle":
        _draw_aromatic_circles(ax_graph, graph_view, pos, scale=0.52)

    node_fills, node_borders = _atom_colors(graph_view, nodes, custom_node_colors)
    node_artist = nx.draw_networkx_nodes(
        graph_view,
        pos,
        nodelist=nodes,
        node_color=node_fills,
        edgecolors=node_borders,
        linewidths=max(1.0, scaled_node_size**0.5 * 0.065),
        node_size=scaled_node_size,
        ax=ax_graph,
    )
    node_artist.set_zorder(3)

    # Choose the label colour from the fill actually drawn (including custom
    # colours) so labels stay readable on every marker.
    for node, fill in zip(nodes, node_fills):
        text = _element_label(graph_view.nodes[node], label_mode=label_mode)
        if not text:
            continue
        x, y = pos[node]
        ax_graph.text(
            x,
            y,
            text,
            ha="center",
            va="center",
            fontsize=element_font_size,
            fontweight="bold",
            color=_label_text_color(fill),
            zorder=8,
        )

    if show_atom_map:
        for node in nodes:
            atom_map = graph_view.nodes[node].get("atom_map", node)
            if atom_map in (None, 0):
                atom_map = node
            x, y = pos[node]
            # Push the number away from the atom's neighbours.
            dx, dy = _index_offset_vec(node, graph_view, pos, base=atom_map_offset)
            ax_graph.text(
                x + dx,
                y + dy,
                str(atom_map),
                ha="center",
                va="center",
                fontsize=atom_map_font_size,
                fontweight="bold",
                color="#111827",
                path_effects=[pe.withStroke(linewidth=2.5, foreground="white")],
                zorder=9,
            )

    if title:
        ax_graph.set_title(title, fontsize=12, fontweight="bold", pad=8)
    _set_padded_limits(ax_graph, pos, avg_len)

    if include_rdkit_panel and ax_rdkit is not None:
        _draw_rdkit_panel(ax_rdkit, graph_view, nodes, use_h_count=use_h_count)
        fig.tight_layout()
        return fig, (ax_rdkit, ax_graph)

    fig.tight_layout()
    return ax_graph


def _resolve_axes(
    ax: Optional[plt.Axes],
    *,
    include_rdkit_panel: bool,
    figsize: Tuple[float, float],
) -> Tuple[plt.Figure, Optional[plt.Axes], plt.Axes]:
    """Return ``(figure, rdkit_axes_or_None, graph_axes)`` for one drawing."""
    if include_rdkit_panel:
        # A fresh side-by-side figure is always created; ``ax`` is not used.
        fig, (ax_rdkit, ax_graph) = plt.subplots(
            1, 2, figsize=(figsize[0] * 2, figsize[1]), facecolor="white"
        )
        return fig, ax_rdkit, ax_graph
    if ax is None:
        fig, ax_graph = plt.subplots(figsize=figsize, facecolor="white")
        return fig, None, ax_graph
    return ax.figure, None, ax


def _atom_colors(
    graph: nx.Graph,
    nodes: list[Any],
    custom_node_colors: Optional[Mapping[Any, str]],
) -> Tuple[list[str], list[str]]:
    """Return parallel ``(fills, borders)`` lists for ``nodes``.

    Palette colours are looked up from each node's ``element`` (default
    ``"C"``); a custom colour, when provided, replaces both fill and border.
    """
    fills: list[str] = []
    borders: list[str] = []
    for node in nodes:
        fill, border = _element_colors(str(graph.nodes[node].get("element", "C")))
        if custom_node_colors and node in custom_node_colors:
            fill = border = custom_node_colors[node]
        fills.append(fill)
        borders.append(border)
    return fills, borders


def _label_text_color(fill: str) -> str:
    """Return white or dark text, whichever contrasts with ``fill``.

    ``fill`` may be any Matplotlib colour specification.  The brightness uses
    the same Rec. 709 weights as the palette-based check, so palette colours
    get the same text colour as before.
    """
    from matplotlib.colors import to_rgb

    red, green, blue = to_rgb(fill)
    brightness = 0.2126 * red + 0.7152 * green + 0.0722 * blue
    return "white" if brightness < 0.5 else "#1f2937"
def _layout_positions(
    graph: nx.Graph,
    nodes: list[Any],
    *,
    use_h_count: bool,
) -> Dict[Any, Tuple[float, float]]:
    """Return 2D coordinates for ``nodes``, preferring an RDKit depiction.

    RDKit atoms are matched to ``nodes`` by index.  This relies on
    ``GraphToMol`` creating atoms in node-insertion order (presumably true;
    ``_ordered_graph`` inserts them in ``nodes`` order).  Any failure during
    conversion or layout falls back to NetworkX's Kamada-Kawai layout.
    """
    try:
        ordered = _ordered_graph(graph, nodes)
        mol = _graph_to_mol(ordered, sanitize=True, use_h_count=use_h_count)
        _ensure_2d(mol)
        conf = mol.GetConformer(0)
        positions: Dict[Any, Tuple[float, float]] = {}
        for idx, node in enumerate(nodes):
            point = conf.GetAtomPosition(idx)
            positions[node] = (point.x, point.y)
        return positions
    except Exception:
        # Deliberate best-effort: any RDKit/conversion problem only means we
        # draw with a generic graph layout instead.
        return {
            node: (float(point[0]), float(point[1]))
            for node, point in nx.kamada_kawai_layout(graph).items()
        }


def _ordered_graph(graph: nx.Graph, nodes: list[Any]) -> nx.Graph:
    """Copy ``graph`` into a plain ``nx.Graph`` whose node order follows ``nodes``."""
    ordered = nx.Graph()
    for node in nodes:
        ordered.add_node(node, **graph.nodes[node])
    for u, v, attrs in graph.edges(data=True):
        ordered.add_edge(u, v, **attrs)
    return ordered


def _graph_to_mol(graph: nx.Graph, *, sanitize: bool, use_h_count: bool) -> Chem.Mol:
    """Convert ``graph`` to an RDKit molecule with ``GraphToMol``.

    If the requested conversion raises, it is retried once with
    ``sanitize=False``.  Exceptions from the retry propagate.
    """
    converter = GraphToMol(
        {
            "element": "element",
            "charge": "charge",
            "atom_map": "atom_map",
            "radical": "radical",
        },
        {"order": "order"},
    )
    try:
        return converter.graph_to_mol(graph, sanitize=sanitize, use_h_count=use_h_count)
    except Exception:
        return converter.graph_to_mol(graph, sanitize=False, use_h_count=use_h_count)


def _ensure_2d(mol: Chem.Mol) -> None:
    """Compute 2D coordinates in place when ``mol`` has no conformer yet."""
    if mol.GetNumConformers() == 0:
        AllChem.Compute2DCoords(mol)


def _element_colors(element: str) -> Tuple[str, str]:
    """Return ``(fill, border)`` for ``element``, falling back to grey."""
    return ELEMENT_PALETTE.get(element, (DEFAULT_FILL, DEFAULT_BORDER))


def _safe_int(value: Any) -> int:
    """Return ``int(value)``, or ``0`` for ``None``/empty/unparseable values."""
    try:
        return int(value or 0)
    except (TypeError, ValueError):
        return 0


def _element_label(attrs: Mapping[str, Any], *, label_mode: str) -> str:
    """Return the text label for an atom, or ``""`` when it is not labelled.

    ``"hetero"`` mode hides neutral, non-radical carbons.  Any atom that is
    labelled gets its charge suffix (e.g. ``"+"``, ``"-2"``) and one ``"."``
    per radical electron, so a charged carbon shows as ``"C+"`` in both
    ``"hetero"`` and ``"all"`` modes.
    """
    if label_mode == "none":
        return ""
    element = str(attrs.get("element", "C"))
    charge = _safe_int(attrs.get("charge", 0))
    radical = _safe_int(attrs.get("radical", 0))
    if label_mode == "hetero" and element == "C" and not (charge or radical):
        return ""
    charge_suffix = _charge_suffix(attrs.get("charge", 0))
    radical_suffix = "." * radical
    return f"{element}{charge_suffix}{radical_suffix}"


def _charge_suffix(charge: Any) -> str:
    """Format a formal charge as ``""``, ``"+"``, ``"-"``, ``"+2"``, ``"-3"``, ...

    Values that cannot be converted with ``int()`` produce ``""``.
    """
    try:
        value = int(charge)
    except (TypeError, ValueError):
        return ""
    if value == 0:
        return ""
    sign = "+" if value > 0 else "-"
    mag = abs(value)
    return sign if mag == 1 else f"{sign}{mag}"


def _edge_order(attrs: Mapping[str, Any], *, aromatic: bool) -> int:
    """Return the number of lines (1-3) to draw for a bond.

    Aromatic bonds are drawn as a single line.  Otherwise ``kekule_order`` is
    preferred over ``order``; the absolute value is rounded and clamped to
    1..3, and unparseable values count as 1.
    """
    if aromatic:
        return 1
    try:
        order = abs(float(attrs.get("kekule_order", attrs.get("order", 1.0))))
    except (TypeError, ValueError):
        order = 1.0
    return max(1, min(3, int(round(order))))


def _edge_is_aromatic(attrs: Mapping[str, Any]) -> bool:
    """Return True if the bond has a truthy ``aromatic`` flag or ``order == 1.5``."""
    if bool(attrs.get("aromatic", False)):
        return True
    try:
        return float(attrs.get("order", 0.0)) == 1.5
    except (TypeError, ValueError):
        return False


def _draw_bond_lines(
    ax: plt.Axes,
    p1: Tuple[float, float],
    p2: Tuple[float, float],
    *,
    order: int,
    aromatic: bool,
    aromatic_style: str,
    offset: float,
    lw: float,
    color: str,
) -> None:
    """Draw one bond between ``p1`` and ``p2``.

    Dashed aromatic bonds are one dashed line.  Other aromatic and single
    bonds are one solid line.  Double bonds are two lines at ``±offset``.
    Triple bonds are a centre line plus two offset lines, all slightly thinner.
    """
    kwargs = {
        "color": color,
        "linewidth": lw,
        "solid_capstyle": "round",
        "solid_joinstyle": "round",
        "zorder": 2,
    }
    if aromatic and aromatic_style == "dashed":
        ax.plot([p1[0], p2[0]], [p1[1], p2[1]], linestyle="--", **kwargs)
        return
    if aromatic or order <= 1:
        ax.plot([p1[0], p2[0]], [p1[1], p2[1]], **kwargs)
        return

    dx, dy = _perp_offset(p1, p2, offset)
    if order == 2:
        ax.plot([p1[0] + dx, p2[0] + dx], [p1[1] + dy, p2[1] + dy], **kwargs)
        ax.plot([p1[0] - dx, p2[0] - dx], [p1[1] - dy, p2[1] - dy], **kwargs)
        return

    thin = {**kwargs, "linewidth": lw * 0.9}
    ax.plot([p1[0], p2[0]], [p1[1], p2[1]], **thin)
    ax.plot([p1[0] + dx, p2[0] + dx], [p1[1] + dy, p2[1] + dy], **thin)
    ax.plot([p1[0] - dx, p2[0] - dx], [p1[1] - dy, p2[1] - dy], **thin)


def _ring_is_aromatic(graph: nx.Graph, cycle: list[Any]) -> bool:
    """Return True when every ring atom, or every ring bond, is aromatic.

    ``cycle`` is expected in ring order (consecutive nodes bonded, last bonded
    to first), as returned by :func:`networkx.cycle_basis`.  A missing bond
    makes the ring count as non-aromatic.
    """
    if all(bool(graph.nodes[node].get("aromatic", False)) for node in cycle):
        return True
    for u, v in zip(cycle, cycle[1:] + cycle[:1]):
        data = graph.get_edge_data(u, v)
        if data is None or not _edge_is_aromatic(data):
            return False
    return True


def _draw_aromatic_circles(
    ax: plt.Axes,
    graph: nx.Graph,
    pos: Mapping[Any, Tuple[float, float]],
    *,
    scale: float,
) -> None:
    """Draw a circle inside each aromatic ring of five or more atoms.

    Rings come from :func:`networkx.cycle_basis`.  Rings whose aromaticity is
    only recorded on bonds (e.g. ``order == 1.5``) qualify too, so their
    single-line bonds still get a circle.  The radius is ``scale`` times the
    mean centre-to-atom distance.
    """
    for cycle in nx.cycle_basis(graph):
        if len(cycle) < 5 or not _ring_is_aromatic(graph, cycle):
            continue
        xs = [pos[node][0] for node in cycle]
        ys = [pos[node][1] for node in cycle]
        cx, cy = sum(xs) / len(xs), sum(ys) / len(ys)
        radius = sum(math.hypot(x - cx, y - cy) for x, y in zip(xs, ys)) / len(xs)
        ax.add_patch(
            mpatches.Circle(
                (cx, cy),
                radius * scale,
                fill=False,
                linewidth=1.15,
                color="#333333",
                zorder=1,
            )
        )


def _draw_highlights(
    ax: plt.Axes,
    graph: nx.Graph,
    pos: Mapping[Any, Tuple[float, float]],
    *,
    highlight_nodes: Optional[Set[Any]],
    highlight_edges: Set[Tuple[Any, Any]],
    node_size: int,
    bond_width: float,
    color: str,
) -> None:
    """Draw translucent underlays for highlighted bonds and atoms.

    ``highlight_edges`` must already be normalised with :func:`_edge_key`.
    Highlighted nodes that are not in ``graph`` are skipped.
    """
    if highlight_edges:
        for u, v in graph.edges():
            if _edge_key(u, v) not in highlight_edges:
                continue
            p1, p2 = pos[u], pos[v]
            ax.plot(
                [p1[0], p2[0]],
                [p1[1], p2[1]],
                color=color,
                linewidth=bond_width * 5.0,
                alpha=0.25,
                solid_capstyle="round",
                zorder=1,
            )
    if highlight_nodes:
        nodes = [node for node in highlight_nodes if node in graph]
        if nodes:
            artist = nx.draw_networkx_nodes(
                graph,
                pos,
                nodelist=nodes,
                node_size=int(node_size * 1.75),
                node_color=color,
                edgecolors="none",
                alpha=0.22,
                ax=ax,
            )
            artist.set_zorder(1)


def _draw_bond_order_label(
    ax: plt.Axes,
    p1: Tuple[float, float],
    p2: Tuple[float, float],
    order: int,
) -> None:
    """Write the bond order in a small white box at the bond midpoint."""
    ax.text(
        (p1[0] + p2[0]) / 2,
        (p1[1] + p2[1]) / 2,
        str(order),
        fontsize=7,
        ha="center",
        va="center",
        color="#111827",
        bbox={"boxstyle": "round,pad=0.12", "fc": "white", "ec": "none", "alpha": 0.9},
        zorder=8,
    )


def _draw_rdkit_panel(
    ax: plt.Axes,
    graph: nx.Graph,
    nodes: list[Any],
    *,
    use_h_count: bool,
) -> None:
    """Render RDKit's own depiction (with atom indices) into ``ax``.

    Failures are shown as text in the panel instead of being raised.
    """
    ax.clear()
    ax.set_axis_off()
    try:
        mol = _graph_to_mol(
            _ordered_graph(graph, nodes), sanitize=True, use_h_count=use_h_count
        )
        _ensure_2d(mol)
        options = Draw.MolDrawOptions()
        options.addAtomIndices = True
        image = Draw.MolToImage(mol, size=(500, 500), kekulize=False, options=options)
        ax.imshow(image)
        ax.set_title("RDKit", fontsize=12, fontweight="bold", pad=8)
    except Exception as exc:
        ax.text(0.5, 0.5, f"RDKit render failed\n{exc}", ha="center", va="center")


def _perp_offset(
    p1: Tuple[float, float],
    p2: Tuple[float, float],
    offset: float,
) -> Tuple[float, float]:
    """Return a vector of length ``offset`` perpendicular to ``p1 -> p2``.

    Coincident points give ``(0.0, 0.0)``.
    """
    dx, dy = p2[0] - p1[0], p2[1] - p1[1]
    length = math.hypot(dx, dy)
    if length == 0:
        return 0.0, 0.0
    return -dy / length * offset, dx / length * offset


def _index_offset_vec(
    node: Any,
    graph: nx.Graph,
    pos: Mapping[Any, Tuple[float, float]],
    *,
    base: float,
) -> Tuple[float, float]:
    """Return a length-``base`` vector pointing from the neighbour centroid to ``node``.

    Isolated nodes, or nodes sitting exactly on their neighbour centroid, get
    a straight-up offset ``(0.0, base)``.
    """
    x, y = pos[node]
    neighbors = list(graph.neighbors(node))
    if not neighbors:
        return 0.0, base
    cx = sum(pos[nbr][0] for nbr in neighbors) / len(neighbors)
    cy = sum(pos[nbr][1] for nbr in neighbors) / len(neighbors)
    dx, dy = x - cx, y - cy
    length = math.hypot(dx, dy)
    if length == 0:
        return 0.0, base
    return dx / length * base, dy / length * base


def _avg_edge_length(
    pos: Mapping[Any, Tuple[float, float]],
    graph: nx.Graph,
) -> float:
    """Return the mean Euclidean bond length, or ``1.0`` for an edgeless graph."""
    if graph.number_of_edges() == 0:
        return 1.0
    lengths = [
        math.hypot(pos[v][0] - pos[u][0], pos[v][1] - pos[u][1])
        for u, v in graph.edges()
    ]
    return sum(lengths) / len(lengths)


def _set_padded_limits(
    ax: plt.Axes,
    pos: Mapping[Any, Tuple[float, float]],
    avg_len: float,
) -> None:
    """Fit the axes limits to ``pos`` with padding so markers are not clipped."""
    if not pos:
        return
    xs = [point[0] for point in pos.values()]
    ys = [point[1] for point in pos.values()]
    x_span = max(xs) - min(xs)
    y_span = max(ys) - min(ys)
    pad = max(avg_len * 0.45, x_span * 0.08, y_span * 0.08, 0.2)
    ax.set_xlim(min(xs) - pad, max(xs) + pad)
    ax.set_ylim(min(ys) - pad, max(ys) + pad)


def _normalize_edge_set(edges: Optional[Set[Tuple[Any, Any]]]) -> Set[Tuple[Any, Any]]:
    """Return ``edges`` with each pair put into :func:`_edge_key` order."""
    if not edges:
        return set()
    return {_edge_key(u, v) for u, v in edges}


def _edge_key(u: Any, v: Any) -> Tuple[Any, Any]:
    """Return an orientation-independent key for the undirected edge ``u``-``v``.

    Endpoints are compared by their ``str`` form so mixed or unorderable node
    types can still be normalised.
    """
    return (u, v) if str(u) <= str(v) else (v, u)


def _luminance(hex_color: str) -> float:
    """Return the Rec. 709-weighted brightness (0..1) of a hex colour.

    Accepts ``#rrggbb``, ``#rrggbbaa``, and the ``#rgb`` / ``#rgba``
    shorthands; any alpha channel is ignored.
    """
    digits = hex_color.lstrip("#")
    if len(digits) in (3, 4):
        digits = "".join(ch * 2 for ch in digits)
    red, green, blue = (int(digits[i : i + 2], 16) / 255.0 for i in (0, 2, 4))  # noqa
    return 0.2126 * red + 0.7152 * green + 0.0722 * blue