"""syn_rule.py
================
Immutable description of a reaction template (SynRule) with canonical forms
and optional implicit-hydrogen stripping.

Key features
------------
* **Fragment decomposition** – splits the ITS graph into rc, left, and right.
* **Implicit H-handling** – converts explicit H nodes into hcount + h_pairs.
* **Canonicalisation** – wraps rc/left/right in SynGraph for stable signatures.
* **Value-object semantics** – `__eq__`/`__hash__` use fragment signatures.

Quick start
-----------
>>> from synkit.Graph.syn_rule import SynRule
>>> rule = SynRule.from_smart("[CH3:1]C>>[CH2:1]C")
>>> rule.left.signature, rule.right.signature
('abc123...', 'def456...')
"""
from __future__ import annotations
from typing import Optional, Tuple
import networkx as nx
from synkit.Graph.syn_graph import SynGraph
from synkit.Graph.canon_graph import GraphCanonicaliser
from synkit.Graph.ITS.its_decompose import its_decompose
from synkit.Graph.ITS.its_reverter import ITSReverter
from synkit.Graph.Hyrogen._misc import normalize_h_pair_graph
from synkit.IO.chem_converter import (
ITSFormat,
detect_its_format,
rsmi_to_its,
gml_to_its,
)
__all__ = ["SynRule"]
# [docs]  (Sphinx "viewcode" link residue; not part of the module)
class SynRule:
    """
    Reaction template: rc, left, and right fragments wrapped as SynGraph objects.

    Instances are meant to be used as immutable value objects.  Nothing
    enforces this at runtime, so callers should not reassign the public
    attributes after construction.

    Parameters
    ----------
    rc : nx.Graph
        Raw reaction-centre (RC) graph.  The constructor works on
        ``rc.copy()``.
    name : str, default ``"rule"``
        Identifier for the rule.
    canonicaliser : Optional[GraphCanonicaliser]
        Custom canonicaliser; if *None* a default is created.
    canon : bool, default ``True``
        If *True*, build canonical forms and signatures.
    implicit_h : bool, default ``True``
        Convert explicit hydrogens in the **rc/left/right** fragments to an
        integer ``hcount`` attribute and record cross-fragment hydrogen pairs
        in a ``h_pairs`` attribute.
    format : Optional[ITSFormat], default ``None``
        ITS representation used by ``rc``.  ``"typesGH"`` and ``"tuple"`` get
        dedicated hydrogen handling; when *None* the format is detected with
        ``detect_its_format``.

    Attributes
    ----------
    rc : SynGraph
        Wrapped reaction-centre graph.
    left : SynGraph
        Wrapped left fragment.
    right : SynGraph
        Wrapped right fragment.
    canonical_smiles : Optional[Tuple[str, str]]
        Pair of left/right fragment signatures (or None if canon=False).

    Notes
    -----
    Equality and hashing use ``canonical_smiles``.  Rules built with
    ``canon=False`` carry no signatures and therefore compare by identity.
    """

    # ------------------------------------------------------------------ #
    # Alternate constructors                                             #
    # ------------------------------------------------------------------ #
    @classmethod
    def from_smart(
        cls,
        smart: str,
        name: str = "rule",
        canonicaliser: Optional[GraphCanonicaliser] = None,
        *,
        canon: bool = True,
        implicit_h: bool = True,
        format: ITSFormat = "typesGH",
    ) -> "SynRule":
        """Instantiate from a SMARTS string.

        The string is converted with ``rsmi_to_its`` using ``format``; the
        same ``format`` is forwarded to the constructor so no detection is
        needed.
        """
        return cls(
            rsmi_to_its(smart, format=format),
            name=name,
            canonicaliser=canonicaliser,
            canon=canon,
            implicit_h=implicit_h,
            format=format,
        )

    @classmethod
    def from_gml(
        cls,
        gml: str,
        name: str = "rule",
        canonicaliser: Optional[GraphCanonicaliser] = None,
        *,
        canon: bool = True,
        implicit_h: bool = True,
    ) -> "SynRule":
        """Instantiate from a GML string.

        The string is converted with ``gml_to_its``; the ITS format is left
        for the constructor to detect.
        """
        return cls(
            gml_to_its(gml),
            name=name,
            canonicaliser=canonicaliser,
            canon=canon,
            implicit_h=implicit_h,
        )

    # ------------------------------------------------------------------ #
    # Initialiser                                                        #
    # ------------------------------------------------------------------ #
    def __init__(
        self,
        rc: nx.Graph,
        name: str = "rule",
        canonicaliser: Optional[GraphCanonicaliser] = None,
        *,
        canon: bool = True,
        implicit_h: bool = True,
        format: Optional[ITSFormat] = None,
    ) -> None:
        """Decompose ``rc``, optionally strip explicit H, and wrap the results."""
        self._name = name
        self._canon_enabled = canon
        self._implicit_h = implicit_h
        self._canonicaliser = canonicaliser or GraphCanonicaliser()

        # Fragment decomposition (format detection runs on the un-normalised copy).
        rc_graph = rc.copy()
        self._format = format or detect_its_format(rc_graph)
        if self._implicit_h:
            rc_graph = normalize_h_pair_graph(rc_graph)
        left_graph, right_graph = self._decompose(rc_graph, self._format)

        # Optional H-stripping.  Both branches edit rc_graph in place and then
        # re-derive left/right from it so the three graphs stay in sync.
        if self._implicit_h and self._format == "typesGH":
            self._strip_explicit_h(rc_graph, left_graph, right_graph)
            self._merge_stripped_h_into_typesgh(rc_graph, left_graph, right_graph)
            left_graph, right_graph = self._decompose(rc_graph, self._format)
        elif self._implicit_h and self._format == "tuple":
            self._strip_explicit_h_tuple(rc_graph, left_graph, right_graph)
            left_graph, right_graph = self._decompose(rc_graph, self._format)

        # ---------- wrap graphs ---------------------------------------- #
        self.rc = SynGraph(rc_graph, self._canonicaliser, canon=canon)
        self.left = SynGraph(left_graph, self._canonicaliser, canon=canon)
        self.right = SynGraph(right_graph, self._canonicaliser, canon=canon)
        self.canonical_smiles: Optional[Tuple[str, str]] = (
            (self.left.signature, self.right.signature) if canon else None
        )

    # ================================================================== #
    # Private utilities                                                  #
    # ================================================================== #
    @staticmethod
    def _decompose(rc: nx.Graph, format: ITSFormat) -> tuple[nx.Graph, nx.Graph]:
        """Return left/right fragments for either supported ITS representation."""
        if format == "tuple":
            reverter = ITSReverter(rc)
            return reverter.to_reactant_graph(), reverter.to_product_graph()
        return its_decompose(rc)

    @staticmethod
    def _merge_stripped_h_into_typesgh(
        rc: nx.Graph,
        left: nx.Graph,
        right: nx.Graph,
    ) -> None:
        """Add the per-side stripped-H counts to each rc node's ``typesGH`` pair.

        Must run right after :meth:`_strip_explicit_h`, when the ``hcount`` on
        the left/right nodes holds only the number of hydrogens just stripped.
        Index 2 of each 5-item ``typesGH`` entry is the slot that is increased.
        """
        for node, attrs in rc.nodes(data=True):
            before, after = attrs["typesGH"]
            stripped_left = left.nodes[node]["hcount"]
            stripped_right = right.nodes[node]["hcount"]
            attrs["typesGH"] = (
                (before[0], before[1], stripped_left + before[2], before[3], before[4]),
                (after[0], after[1], stripped_right + after[2], after[3], after[4]),
            )

    @staticmethod
    def _strip_explicit_h(
        rc: nx.Graph,
        left: nx.Graph,
        right: nx.Graph,
    ) -> None:
        """Remove explicit hydrogens from rc, left, right, but only when *both*
        left & right agree the H should be implicit.

        Otherwise an H remains explicit in all three graphs.  The three graphs
        are modified in place: every node's ``hcount`` is reset to 0 and then
        incremented once per stripped hydrogen neighbour, and heavy atoms gain
        an ``h_pairs`` list holding the pair-IDs of the stripped hydrogens.
        """

        def _removable_on(graph: nx.Graph, h) -> bool:
            # A node missing on this side can never be agreed upon; this guard
            # also keeps ``neighbors`` from raising for absent nodes.
            if not graph.has_node(h):
                return False
            nbrs = list(graph.neighbors(h))
            # H+ (no neighbours) or H-H only => keep explicit;
            # bonded to at least one heavy atom => removable.
            return bool(nbrs) and not all(
                graph.nodes[n].get("element") == "H" for n in nbrs
            )

        def _fully_removable(h) -> bool:
            # only remove if BOTH left and right say removable
            return _removable_on(left, h) and _removable_on(right, h)

        def _fold_into_neighbours(graph: nx.Graph, h, pair_id=None) -> None:
            # Credit every heavy neighbour with one implicit H, then drop the node.
            for nbr in list(graph.neighbors(h)):
                nbr_attrs = graph.nodes[nbr]
                if nbr_attrs.get("element") != "H":
                    nbr_attrs["hcount"] += 1
                    if pair_id is not None:
                        nbr_attrs.setdefault("h_pairs", []).append(pair_id)
            graph.remove_node(h)

        # 1) initialise hcount & h_pairs
        for graph in (rc, left, right):
            for _, data in graph.nodes(data=True):
                data["hcount"] = 0
                if data.get("element") != "H":
                    data.setdefault("h_pairs", [])

        # 2) shared H: only those removable on both sides; sorted so that the
        #    pair-IDs are assigned deterministically.
        shared = sorted(
            n
            for n, d in left.nodes(data=True)
            if d.get("element") == "H" and _fully_removable(n)
        )
        for pair_id, h in enumerate(shared, start=1):
            for graph in (left, right, rc):
                if graph.has_node(h):
                    _fold_into_neighbours(graph, h, pair_id)

        # 3) remaining explicit H in any graph: strip only if fully removable.
        #    The decision is taken once, *before* any removal, so it cannot
        #    depend on the order in which the three graphs are processed.
        leftovers = {
            h
            for graph in (rc, left, right)
            for h, d in graph.nodes(data=True)
            if d.get("element") == "H" and _fully_removable(h)
        }
        for graph in (rc, left, right):
            for h in leftovers:
                if graph.has_node(h):
                    _fold_into_neighbours(graph, h)

    @staticmethod
    def _strip_explicit_h_tuple(
        rc: nx.Graph,
        left: nx.Graph,
        right: nx.Graph,
    ) -> None:
        """Tuple-style equivalent of legacy explicit-H stripping.

        Hydrogens that are bonded to at least one non-H atom on both sides are
        removed from ``left``, ``right`` and ``rc`` (all modified in place).
        Their heavy neighbours receive an incremented ``hcount`` plus the
        bookkeeping attributes ``h_pairs``, ``h_pairs_left``/``h_pairs_right``
        and ``h_pair_atom_maps``; the rc nodes are then refreshed from the two
        sides, including the hcount slot (index 2) of ``typesGH`` if present.
        """

        def _removable_on(graph: nx.Graph, h) -> bool:
            if not graph.has_node(h):
                return False
            nbrs = list(graph.neighbors(h))
            if not nbrs:
                return False
            return not all(graph.nodes[n].get("element") == "H" for n in nbrs)

        def _fully_removable(h) -> bool:
            return _removable_on(left, h) and _removable_on(right, h)

        def _is_rc_hydrogen(attrs: dict) -> bool:
            # In the rc graph the element is stored as a (left, right) pair.
            return attrs.get("element") == ("H", "H")

        def _ensure_pair_attrs(attrs: dict) -> None:
            attrs.setdefault("h_pairs", [])
            attrs.setdefault("h_pairs_left", [])
            attrs.setdefault("h_pairs_right", [])
            attrs.setdefault("h_pair_atom_maps", {})

        # Give every heavy atom the bookkeeping containers up front.
        for graph in (left, right):
            for _, data in graph.nodes(data=True):
                if data.get("element") != "H":
                    _ensure_pair_attrs(data)
        for _, data in rc.nodes(data=True):
            if not _is_rc_hydrogen(data):
                _ensure_pair_attrs(data)

        # Sorted so that pair-IDs are assigned deterministically.
        removable = sorted(
            node
            for node, attrs in left.nodes(data=True)
            if attrs.get("element") == "H" and _fully_removable(node)
        )

        for pair_id, h in enumerate(removable, start=1):
            atom_map = left.nodes[h].get("atom_map", h)
            for side, graph in (("left", left), ("right", right)):
                for nbr in list(graph.neighbors(h)):
                    nbr_attrs = graph.nodes[nbr]
                    if nbr_attrs.get("element") == "H":
                        continue
                    # Tolerate a missing hcount, matching the ``.get(..., 0)``
                    # reads performed when the rc nodes are refreshed below.
                    nbr_attrs["hcount"] = nbr_attrs.get("hcount", 0) + 1
                    nbr_attrs.setdefault("h_pairs", []).append(pair_id)
                    nbr_attrs.setdefault(f"h_pairs_{side}", []).append(pair_id)
                    nbr_attrs.setdefault("h_pair_atom_maps", {})[pair_id] = atom_map
                graph.remove_node(h)
            if rc.has_node(h):
                rc.remove_node(h)

        # Mirror the per-side results onto the surviving rc nodes.
        for node, attrs in rc.nodes(data=True):
            if node not in left or node not in right:
                continue
            if _is_rc_hydrogen(attrs):
                continue
            left_attrs = left.nodes[node]
            right_attrs = right.nodes[node]
            left_h = left_attrs.get("hcount", 0)
            right_h = right_attrs.get("hcount", 0)
            attrs["hcount"] = (left_h, right_h)
            attrs["h_pairs"] = sorted(
                set(left_attrs.get("h_pairs", []))
                | set(right_attrs.get("h_pairs", []))
            )
            attrs["h_pairs_left"] = sorted(left_attrs.get("h_pairs_left", []))
            attrs["h_pairs_right"] = sorted(right_attrs.get("h_pairs_right", []))
            attrs["h_pair_atom_maps"] = {
                **left_attrs.get("h_pair_atom_maps", {}),
                **right_attrs.get("h_pair_atom_maps", {}),
            }
            typesgh = attrs.get("typesGH")
            if typesgh and len(typesgh) == 2:
                react_attr, prod_attr = typesgh
                attrs["typesGH"] = (
                    tuple(list(react_attr[:2]) + [left_h] + list(react_attr[3:])),
                    tuple(list(prod_attr[:2]) + [right_h] + list(prod_attr[3:])),
                )

    # ================================================================== #
    # Dunder methods                                                     #
    # ================================================================== #
    def __eq__(self, other: object) -> bool:
        """Compare by fragment signatures; fall back to identity without them."""
        if not isinstance(other, SynRule):
            return NotImplemented
        if self.canonical_smiles is None or other.canonical_smiles is None:
            # Without signatures there is nothing to compare: treating every
            # uncanonicalised rule as equal would merge unrelated rules in
            # sets and dict keys.
            return self is other
        return self.canonical_smiles == other.canonical_smiles

    def __hash__(self) -> int:
        """Hash consistently with :meth:`__eq__`."""
        if self.canonical_smiles is None:
            return object.__hash__(self)
        return hash(self.canonical_smiles)

    def __str__(self) -> str:
        """Short label showing the first 8 characters of each signature."""
        if self._canon_enabled and self.canonical_smiles:
            ls, rs = self.canonical_smiles
            return f"<SynRule {self._name!r} left={ls[:8]}… right={rs[:8]}…>"
        return f"<SynRule {self._name!r} (raw only)>"

    def __repr__(self) -> str:
        """Summarise node/edge counts of the raw rc, left and right graphs."""

        def _size(attr: str) -> tuple:
            raw = getattr(self, attr).raw
            return raw.number_of_nodes(), raw.number_of_edges()

        try:
            (v_rc, e_rc), (v_l, e_l), (v_r, e_r) = (
                _size("rc"),
                _size("left"),
                _size("right"),
            )
        except Exception:
            # Deliberately broad: repr() must stay usable on a partially
            # constructed instance (e.g. inside a debugger or traceback).
            v_rc = e_rc = v_l = e_l = v_r = e_r = 0
        return (
            f"SynRule(name={self._name!r}, "
            f"rc=(|V|={v_rc},|E|={e_rc}), "
            f"left=(|V|={v_l},|E|={e_l}), "
            f"right=(|V|={v_r},|E|={e_r}))"
        )

    # ================================================================== #
    # Public API                                                         #
    # ================================================================== #
    def help(self) -> None:
        """Pretty-print raw / canonical contents for quick inspection."""
        print(f"SynRule name={self._name!r}")
        print("→ Full (raw) rc_graph edges:")
        for u, v, d in self.rc.raw.edges(data=True):
            print(f" ({u}, {v}): {d}")
        if not self._canon_enabled:
            print("→ Canonicalisation disabled.")
            return
        print("\n→ Full canonical_graph edges:")
        for u, v, d in self.rc.canonical.edges(data=True):  # type: ignore[attr-defined]
            print(f" ({u}, {v}): {d}")
        print("\n→ Left fragment:")
        self.left.help()
        print("\n→ Right fragment:")
        self.right.help()
        print("\n→ Fragment signatures:", self.canonical_smiles)