"""MTG – Mechanistic Transition Graph fusion utility.

This module exposes :class:`~MTG`, a helper that merges a chronological
sequence of **Intermediate Transition State** (ITS) graphs – or their RSMI
string representations – into a single *product* graph capturing the entire
bond-order history across the reaction trajectory.

The implementation is self-contained except for the external *synkit* helpers
used for RSMI⇒ITS inter-conversion and canonicalisation.
"""

# The docstring must be the first statement of the module to become
# ``__doc__``; ``from __future__`` imports are allowed directly after it.
from __future__ import annotations
from collections.abc import Iterator
from typing import Any, Dict, List, Mapping, MutableMapping, Set, Tuple, Union
import networkx as nx
# ---------------------------------------------------------------------------
# Optional dependencies
# ---------------------------------------------------------------------------
try:
import pandas as pd # type: ignore
except ImportError: # pragma: no cover – pandas is only required for to_dataframe()
pd = None # noqa: N816
from synkit.Graph.Hyrogen._misc import h_to_explicit
from synkit.Graph.ITS.normalize_aam import NormalizeAAM
from synkit.Graph.MTG.mcs_matcher import MCSMatcher
from synkit.Graph.MTG.utils import (
normalize_hcount_and_typesGH,
normalize_order,
label_mtg_edges,
compute_standard_order,
)
from synkit.Graph.canon_graph import GraphCanonicaliser
from synkit.IO import ITSFormat, its_to_rsmi, rsmi_to_its
# ---------------------------------------------------------------------------
# Type aliases and module constants
# ---------------------------------------------------------------------------
NodeID = int
MissingOrder = Tuple[Set[float], Set[float]]
GraphMapping = Dict[NodeID, NodeID]

# Marker stored in an edge history for snapshots where the edge was not seen.
# A pair of sets cannot be confused with a genuine numeric order pair.
_PLACEHOLDER: MissingOrder = (set(), set())
# Marker stored in a legacy ``typesGH`` history for snapshots without the node.
_PLACEHOLDER_TYPESGH = (set(), set(), set(), set(), set())

# Edge attributes whose per-step pairs are tracked across the trajectory.
_TUPLE_EDGE_ATTRS = ("order", "kekule_order", "sigma_order", "pi_order")
# Tuple-ITS node attributes collapsed to a single value on the fused node.
_TUPLE_NODE_SCALAR_ATTRS = ("element", "atom_map", "valence_electrons")
# Tuple-ITS node attributes kept as a timeline of ``k + 1`` states.
_TUPLE_NODE_TIMELINE_ATTRS = (
    "aromatic",
    "hcount",
    "charge",
    "radical",
    "lone_pairs",
    "present",
)

__all__ = ["MTG"]


class MTG:
    """Fuse a chronological series of ITS graphs into a Mechanistic Transition Graph.

    :param sequences: A list of ITS-format NetworkX graphs or RSMI strings.
    :param mappings: Optional list of precomputed mappings; computed via MCS if None.
    :param node_label_names: Keys for node-label matching.
    :param canonicaliser: Optional GraphCanonicaliser for snapshot canonicalisation.
    :param mcs_mol: Forwarded to ``MCSMatcher.find_rc_mapping`` (legacy graphs).
    :param mcs: Forwarded to ``MCSMatcher.find_rc_mapping`` (legacy graphs).
    :param its_format: ITS format used when ``sequences`` contains RSMI strings.
        Defaults to ``"tuple"`` for Lewis State Graph MTGs. Pass
        ``"typesGH"`` to build legacy MTGs from strings.
    :raises ValueError: On invalid sequence or mapping lengths.
    :raises TypeError: If a sequence item is neither a graph nor a string.
    :raises RuntimeError: On mapping failures.
    """

    def __init__(
        self,
        sequences: Union[List[nx.Graph], List[str]],
        mappings: List[GraphMapping] | None = None,
        *,
        node_label_names: List[str] | None = None,
        canonicaliser: GraphCanonicaliser | None = None,
        mcs_mol: bool = False,
        mcs: bool = False,
        its_format: ITSFormat = "tuple",
    ) -> None:
        if len(sequences) < 2:
            raise ValueError("Need at least two snapshots.")

        self._node_label_names = node_label_names or ["element", "charge", "hcount"]
        self._canonicaliser = canonicaliser
        self.mcs_mol = mcs_mol
        self.mcs = mcs
        self.its_format = its_format

        self._graphs = self._prepare_graph_sequence(sequences)
        self._k = len(self._graphs)
        # Must be known before mappings are computed: _compute_mappings
        # dispatches on it.
        self._tuple_its = all(self._is_tuple_its(g) for g in self._graphs)

        self._mappings = (
            mappings if mappings is not None else self._compute_mappings(self._graphs)
        )
        if len(self._mappings) != self._k - 1:
            raise ValueError("Mappings must match snapshot pairs.")

        # Populated by the two build steps below.
        self._prod_nodes: Dict[int, Dict[str, Any]]
        self._node_map: Dict[Tuple[int, NodeID], int]
        self._graph: nx.Graph
        self._build_node_map_and_attributes()
        self._build_edge_history_and_graph()

    # ------------------------------------------------------------------
    # Dunder helpers
    # ------------------------------------------------------------------
    def __repr__(self) -> str:
        """Return a summary with snapshot count and fused graph size."""
        return f"<MTG k={self._k} nodes={self._graph.number_of_nodes()} edges={self._graph.number_of_edges()}>"

    def __len__(self) -> int:
        """Return the number of fused nodes."""
        return self._graph.number_of_nodes()

    def __iter__(self) -> Iterator[int]:
        """Iterate over fused node identifiers."""
        return iter(self._graph.nodes)

    def __getitem__(self, node_id: int) -> Dict[str, Any]:
        """Return the attribute dict of fused node ``node_id``."""
        return self._graph.nodes[node_id]

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    @staticmethod
    def describe() -> str:
        """Return a short inline usage example."""
        return (
            "# Usage example\n"
            "mtg = MTG([G0, G1, G2])\n"
            "mg = mtg.get_mtg()\n"
            "rsmi = mtg.get_aam()\n"
        )

    def get_mtg(self, *, directed: bool = False) -> nx.Graph:
        """Return the fused graph.

        With ``directed=False`` the internal graph object itself is returned
        (not a copy); with ``directed=True`` the result of ``to_directed()``.
        """
        return self._graph.to_directed() if directed else self._graph

    def get_its_steps(self, *, directed: bool = False) -> List[nx.Graph]:
        """Reconstruct the ordered list of per-step ITS graphs from the MTG.

        Legacy (non-tuple) inputs are returned as copies of the prepared
        snapshots; ``directed`` is only used for tuple-ITS MTGs.
        """
        if not self._tuple_its:
            return [graph.copy() for graph in self._graphs]
        graph = self.get_mtg(directed=directed)
        return [self._tuple_step_its(graph, step) for step in range(self._k)]

    def get_rsmi_steps(
        self,
        *,
        directed: bool = False,
        explicit_hydrogen: bool = False,
        sanitize: bool = True,
    ) -> List[str]:
        """Serialize reconstructed per-step ITS graphs to reaction SMILES."""
        fmt = "tuple" if self._tuple_its else "typesGH"
        return [
            its_to_rsmi(
                its,
                format=fmt,
                explicit_hydrogen=explicit_hydrogen,
                sanitize=sanitize,
            )
            for its in self.get_its_steps(directed=directed)
        ]

    def get_compose_its(self, *, directed: bool = False) -> nx.Graph:
        """Collapse the MTG into a single composed ITS graph.

        Tuple-ITS MTGs collapse node and edge timelines to
        ``(first, last)`` pairs; legacy MTGs go through the synkit
        label/normalise helpers. Both paths finish with
        ``compute_standard_order``.
        """
        g = self.get_mtg(directed=directed)
        if self._tuple_its:
            g = self._compose_tuple_node_attrs(g)
            g = self._compose_tuple_edge_attrs(g)
        else:
            g = label_mtg_edges(g, inplace=False)
            g = normalize_order(g)
            g = normalize_hcount_and_typesGH(g)
        return compute_standard_order(g)

    def get_aam(self, *, directed: bool = False, explicit_h: bool = False) -> str:
        """Export the composed ITS as a reaction SMILES string.

        :param directed: Compose from the directed view of the MTG.
        :param explicit_h: If True return the explicit-hydrogen RSMI as is;
            otherwise pass it through ``NormalizeAAM().fit``.
        """
        g = self.get_compose_its(directed=directed)
        # Pass the same format selection as get_rsmi_steps so tuple and
        # legacy composites are serialised consistently.
        fmt = "tuple" if self._tuple_its else "typesGH"
        rsmi = its_to_rsmi(g, format=fmt, explicit_hydrogen=True)
        if explicit_h:
            return rsmi
        return NormalizeAAM().fit(rsmi, fix_aam_indice=False)

    def to_dataframe(self):
        """Return fused node attributes as a pandas DataFrame indexed by node id.

        :raises RuntimeError: If pandas is not installed.
        """
        if pd is None:
            raise RuntimeError("pandas required for DataFrame export.")
        return pd.DataFrame.from_dict(
            dict(self._graph.nodes(data=True)), orient="index"
        )

    # ------------------------------------------------------------------
    # Node fusion
    # ------------------------------------------------------------------
    @staticmethod
    def _merge_attrs(lhs: MutableMapping[str, Any], rhs: Mapping[str, Any]) -> None:
        """Fill missing or falsy entries of ``lhs`` in place from ``rhs``.

        A key is overwritten when its current value in ``lhs`` is falsy
        (including ``0`` and ``False``) and the ``rhs`` value is not None.
        """
        for k, v in rhs.items():
            if not lhs.get(k) and v is not None:
                lhs[k] = v

    def _build_node_map_and_attributes(self) -> None:
        """Build ``_prod_nodes`` (pid -> attrs) and ``_node_map``.

        Tuple-ITS graphs that carry atom maps are fused by atom-map number.
        Otherwise the last snapshot seeds the product nodes and earlier
        snapshots are merged backwards through ``self._mappings``.
        """
        if self._tuple_its and self._has_tuple_atom_maps(self._graphs):
            self._build_tuple_node_map_and_attributes()
            return

        prod: Dict[int, Dict[str, Any]] = {}
        node_map: Dict[Tuple[int, NodeID], int] = {}

        last = self._graphs[-1]
        for nid, attrs in last.nodes(data=True):
            prod[nid] = attrs.copy()
            node_map[(self._k - 1, nid)] = nid
        pid_counter = max(prod, default=-1) + 1

        # Merge attributes backwards: a node mapped onto an already-fused
        # successor shares its pid, anything else gets a fresh pid.
        for i in range(self._k - 2, -1, -1):
            G, mp = self._graphs[i], self._mappings[i]
            for nid, attrs in G.nodes(data=True):
                tgt = mp.get(nid)
                if tgt is not None and (i + 1, tgt) in node_map:
                    pid = node_map[(i + 1, tgt)]
                    self._merge_attrs(prod[pid], attrs)
                else:
                    pid = pid_counter
                    prod[pid] = attrs.copy()
                    pid_counter += 1
                node_map[(i, nid)] = pid

        if self._tuple_its:
            self._simplify_tuple_node_attrs(prod, node_map)
        else:
            # Index pid -> {snapshot -> node} once instead of rescanning
            # node_map for every (pid, snapshot). ``setdefault`` keeps the
            # first node in node_map order when several nodes of one
            # snapshot share a pid.
            refs_by_pid: Dict[int, Dict[int, NodeID]] = {}
            for (gi, nid), pid in node_map.items():
                refs_by_pid.setdefault(pid, {}).setdefault(gi, nid)

            default_pair = (_PLACEHOLDER_TYPESGH, _PLACEHOLDER_TYPESGH)
            for pid, attrs in prod.items():
                refs = refs_by_pid[pid]
                first = min(refs)  # earliest snapshot containing this pid
                hist: List[Any] = []
                for i in range(self._k):
                    if i not in refs:
                        hist.append(_PLACEHOLDER_TYPESGH)
                        continue
                    types_gh = self._graphs[i].nodes[refs[i]].get(
                        "typesGH", default_pair
                    )
                    # The first appearance keeps the full typesGH value;
                    # later snapshots contribute only their last element.
                    hist.append(types_gh if i == first else types_gh[-1])
                attrs["typesGH_history"] = tuple(hist)
                attrs["typesGH"] = attrs["typesGH_history"]

        self._prod_nodes = prod
        self._node_map = node_map

    def _build_tuple_node_map_and_attributes(self) -> None:
        """Fuse tuple-ITS nodes across snapshots by atom-map number.

        A node's pid is its atom-map number. Nodes without one, or whose
        number was already used within the same snapshot, receive a fresh
        pid that is guaranteed not to be any atom-map number in the
        trajectory.
        """
        prod: Dict[int, Dict[str, Any]] = {}
        node_map: Dict[Tuple[int, NodeID], int] = {}

        # Reserve every atom-map number up front so a fresh pid handed to an
        # unmapped node can never collide with a mapped atom seen later.
        reserved: Set[int] = set()
        for graph in self._graphs:
            for _, attrs in graph.nodes(data=True):
                mapped = self._tuple_node_pid(attrs)
                if mapped is not None:
                    reserved.add(mapped)

        pid_counter = 0
        for gi, graph in enumerate(self._graphs):
            used_in_graph: Set[int] = set()
            for nid, attrs in graph.nodes(data=True):
                pid = self._tuple_node_pid(attrs)
                if pid is None or pid in used_in_graph:
                    while pid_counter in prod or pid_counter in reserved:
                        pid_counter += 1
                    pid = pid_counter
                    pid_counter += 1
                prod.setdefault(pid, {})
                node_map[(gi, nid)] = pid
                used_in_graph.add(pid)

        self._simplify_tuple_node_attrs(prod, node_map)
        self._prod_nodes = prod
        self._node_map = node_map

    # ------------------------------------------------------------------
    # Edge fusion
    # ------------------------------------------------------------------
    def _build_edge_history_and_graph(self) -> None:
        """Assemble ``self._graph`` from fused nodes and per-edge histories.

        Legacy MTG edges carry ``order`` as a tuple of ``k`` per-step
        values. Tuple-ITS MTG edges carry one state timeline per tracked
        attribute plus ``steps``, the snapshot indices where the edge was
        observed.

        :raises RuntimeError: If adding edges changed the node count.
        """
        hist: Dict[Tuple[int, int], Dict[str, List[MissingOrder]]] = {}
        for i, G in enumerate(self._graphs):
            for u, v, a in G.edges(data=True):
                pu, pv = self._node_map[(i, u)], self._node_map[(i, v)]
                key = tuple(sorted((pu, pv)))  # undirected canonical key
                attr_hist = hist.setdefault(
                    key,
                    {name: [_PLACEHOLDER] * self._k for name in _TUPLE_EDGE_ATTRS},
                )
                for name in _TUPLE_EDGE_ATTRS:
                    attr_hist[name][i] = a.get(name, _PLACEHOLDER)

        g = nx.Graph()
        g.add_nodes_from(self._prod_nodes.items())
        for (u, v), attr_hist in hist.items():
            if self._tuple_its:
                u_present = g.nodes[u].get("present")
                v_present = g.nodes[v].get("present")
                attrs: Dict[str, Any] = {
                    name: self._edge_pair_history_to_timeline(
                        tuple(values), u_present, v_present
                    )
                    for name, values in attr_hist.items()
                }
                attrs["steps"] = tuple(
                    i
                    for i, value in enumerate(attr_hist["order"])
                    if self._is_observed_pair(value)
                )
            else:
                attrs = {"order": tuple(attr_hist["order"])}
            g.add_edge(u, v, **attrs)

        if g.number_of_nodes() != len(self._prod_nodes):
            raise RuntimeError("Node count mismatch.")
        self._graph = g

    # ------------------------------------------------------------------
    # Tuple-ITS helpers
    # ------------------------------------------------------------------
    def _simplify_tuple_node_attrs(
        self,
        prod: Dict[int, Dict[str, Any]],
        node_map: Dict[Tuple[int, NodeID], int],
    ) -> None:
        """
        Replace tuple-ITS node attrs with compact MTG attrs.

        A path of ``k`` ITS steps has ``k + 1`` mechanism states: the first
        step's left side followed by each step's right side. ``prod`` is
        updated in place.
        """
        refs_by_pid: Dict[int, Dict[int, NodeID]] = {}
        for (gi, nid), pid in node_map.items():
            refs_by_pid.setdefault(pid, {})[gi] = nid

        for pid, refs in refs_by_pid.items():
            simplified: Dict[str, Any] = {}
            for key in _TUPLE_NODE_SCALAR_ATTRS:
                timeline = self._node_attr_timeline(refs, key)
                # Scalar attrs keep the first non-None state.
                simplified[key] = next(
                    (value for value in timeline if value is not None),
                    None,
                )
            for key in _TUPLE_NODE_TIMELINE_ATTRS:
                simplified[key] = self._node_attr_timeline(refs, key)
            simplified["steps"] = tuple(sorted(refs))
            prod[pid] = simplified

    def _node_attr_timeline(
        self,
        refs: Dict[int, NodeID],
        key: str,
    ) -> Tuple[Any, ...]:
        """Return the ``k + 1`` state timeline of ``key`` for one fused node.

        ``refs`` maps snapshot index to the node id in that snapshot. States
        not covered by any paired value stay ``None``.
        """
        timeline: List[Any] = [None] * (self._k + 1)
        for gi in range(self._k):
            nid = refs.get(gi)
            if nid is None:
                continue
            value = self._graphs[gi].nodes[nid].get(key)
            if self._is_pair(value):
                timeline[gi] = value[0]
                timeline[gi + 1] = value[1]
        return tuple(timeline)

    def _compose_tuple_node_attrs(self, graph: nx.Graph) -> nx.Graph:
        """
        Collapse tuple-MTG node attrs to ``(initial, final)`` pairs.

        Works on a copy of ``graph``. Scalar attrs become ``(value, value)``;
        non-empty timeline attrs become ``(timeline[0], timeline[-1])``.

        NOTE(review): the end states are taken positionally, so an atom
        absent from the first or last state yields ``None`` on that side --
        confirm this is what composed-ITS consumers expect.
        """
        out = graph.copy()
        for _, attrs in out.nodes(data=True):
            for key in _TUPLE_NODE_SCALAR_ATTRS:
                value = attrs.get(key)
                attrs[key] = (value, value)
            for key in _TUPLE_NODE_TIMELINE_ATTRS:
                timeline = attrs.get(key)
                if isinstance(timeline, tuple) and timeline:
                    attrs[key] = (timeline[0], timeline[-1])
        return out

    def _compose_tuple_edge_attrs(self, graph: nx.Graph) -> nx.Graph:
        """Collapse tuple edge timelines to first-state / final-state pairs."""
        out = graph.copy()
        for _, _, attrs in out.edges(data=True):
            for name in _TUPLE_EDGE_ATTRS:
                timeline = attrs.get(name)
                # Missing attrs and empty timelines are left untouched.
                if isinstance(timeline, tuple) and timeline:
                    attrs[name] = (timeline[0], timeline[-1])
        return out

    def _tuple_step_its(self, graph: nx.Graph, step: int) -> nx.Graph:
        """Extract one paired tuple ITS step from compact tuple-MTG timelines."""
        its = nx.Graph()
        for node, attrs in graph.nodes(data=True):
            if step not in attrs.get("steps", ()):
                continue
            present_pair = self._timeline_pair(attrs.get("present"), step)
            if present_pair[0] is None or present_pair[1] is None:
                continue
            node_attrs: Dict[str, Any] = {}
            for key in _TUPLE_NODE_SCALAR_ATTRS:
                value = attrs.get(key)
                node_attrs[key] = (value, value)
            for key in _TUPLE_NODE_TIMELINE_ATTRS:
                value = self._timeline_pair(attrs.get(key), step)
                if value != (None, None):
                    node_attrs[key] = value
            its.add_node(node, **node_attrs)

        for u, v, attrs in graph.edges(data=True):
            if step not in attrs.get("steps", ()):
                continue
            edge_attrs: Dict[str, Any] = {}
            has_edge = False
            for key in _TUPLE_EDGE_ATTRS:
                value = self._timeline_pair(attrs.get(key), step)
                if value == (None, None):
                    continue
                edge_attrs[key] = value
                # An edge exists in this step only if its order is defined
                # on both sides and not zero on both. ``(0, 0)`` also
                # equals ``(0.0, 0.0)``.
                if (
                    key == "order"
                    and value[0] is not None
                    and value[1] is not None
                    and value != (0, 0)
                ):
                    has_edge = True
            if has_edge and u in its and v in its:
                its.add_edge(u, v, **edge_attrs)
        return compute_standard_order(its)

    # ------------------------------------------------------------------
    # Input preparation and format detection
    # ------------------------------------------------------------------
    def _prepare_graph_sequence(
        self, seq: List[nx.Graph] | List[str]
    ) -> List[nx.Graph]:
        """Convert the raw input sequence into prepared ITS graphs.

        Strings are parsed with ``rsmi_to_its`` using ``self.its_format``.
        Each graph is optionally canonicalised. Tuple-ITS graphs are kept as
        is; legacy graphs get explicit hydrogens and normalised
        ``hcount``/``typesGH``.

        :raises TypeError: If an item is neither ``nx.Graph`` nor ``str``.
        """
        out: List[nx.Graph] = []
        for item in seq:
            if isinstance(item, str):
                g = rsmi_to_its(item, core=False, format=self.its_format)
            elif isinstance(item, nx.Graph):
                g = item
            else:
                raise TypeError(
                    "Sequences must contain either NetworkX graphs or RSMI strings."
                )
            if self._canonicaliser:
                g = self._canonicaliser.canonicalise_graph(g).canonical_graph
            if self._is_tuple_its(g):
                out.append(g)
                continue
            g = h_to_explicit(g, its=True)
            out.append(normalize_hcount_and_typesGH(g))
        return out

    @staticmethod
    def _is_tuple_its(graph: nx.Graph) -> bool:
        """
        Detect paired-attribute ITS graphs produced by the newer tuple format.

        Tuple ITS nodes carry side-specific attributes directly, such as
        ``element=("C", "C")`` and ``lone_pairs=(0, 0)``. Legacy ITS graphs
        instead keep the paired state primarily in ``typesGH``. Only the
        first node is inspected; an empty graph is not tuple ITS.
        """
        if graph.number_of_nodes() == 0:
            return False
        _, attrs = next(iter(graph.nodes(data=True)))
        element = attrs.get("element")
        return isinstance(element, tuple) and len(element) == 2

    @staticmethod
    def _is_pair(value: Any) -> bool:
        """Return True if ``value`` is a tuple of length two."""
        return isinstance(value, tuple) and len(value) == 2

    @classmethod
    def _is_observed_pair(cls, value: Any) -> bool:
        """Return True for a pair that is not the set/set placeholder."""
        return cls._is_pair(value) and not (
            isinstance(value[0], set) and isinstance(value[1], set)
        )

    @staticmethod
    def _timeline_pair(timeline: Any, step: int) -> Tuple[Any, Any]:
        """Return ``(timeline[step], timeline[step + 1])``.

        Returns ``(None, None)`` when ``timeline`` is not a tuple or ``step``
        is out of range. Negative steps are rejected so they cannot wrap
        around through negative indexing.
        """
        if (
            not isinstance(timeline, tuple)
            or step < 0
            or len(timeline) <= step + 1
        ):
            return (None, None)
        return (timeline[step], timeline[step + 1])

    @classmethod
    def _edge_pair_history_to_timeline(
        cls,
        history: Tuple[Any, ...],
        u_present: Any,
        v_present: Any,
    ) -> Tuple[Any, ...]:
        """
        Convert ITS step-pair history into mechanism-state timeline.

        Example: ``((2, 1), (1, 2))`` becomes ``(2, 1, 2)``.
        Missing edge states are ``0.0`` when both endpoint atoms exist and
        ``None`` when an endpoint is absent.
        """
        if not history:
            return ()
        timeline: List[Any] = [None] * (len(history) + 1)
        for idx, value in enumerate(history):
            if cls._is_observed_pair(value):
                timeline[idx] = value[0]
                timeline[idx + 1] = value[1]
        return tuple(
            cls._fill_missing_edge_state(value, idx, u_present, v_present)
            for idx, value in enumerate(timeline)
        )

    @staticmethod
    def _fill_missing_edge_state(
        value: Any,
        idx: int,
        u_present: Any,
        v_present: Any,
    ) -> Any:
        """Return ``value``, or a fill value when it is ``None``.

        The fill is ``0.0`` when both presence timelines are tuples holding
        ``True`` at ``idx``; otherwise ``None``.
        """
        if value is not None:
            return value
        both_present = (
            isinstance(u_present, tuple)
            and isinstance(v_present, tuple)
            and len(u_present) > idx
            and len(v_present) > idx
            and u_present[idx] is True
            and v_present[idx] is True
        )
        return 0.0 if both_present else None

    # ------------------------------------------------------------------
    # Mapping helpers
    # ------------------------------------------------------------------
    def _compute_mappings(self, graphs: List[nx.Graph]) -> List[GraphMapping]:
        """Compute one node mapping per pair of adjacent snapshots.

        Tuple-ITS graphs are matched by atom map; legacy graphs via
        ``MCSMatcher``.

        :raises RuntimeError: If the matcher finds no mapping for a pair.
        """
        if self._tuple_its:
            return [
                self._compute_tuple_mapping(graphs[i], graphs[i + 1])
                for i in range(len(graphs) - 1)
            ]
        mappings: List[GraphMapping] = []
        for i in range(len(graphs) - 1):
            m = MCSMatcher(node_label_names=self._node_label_names)
            m.find_rc_mapping(
                graphs[i], graphs[i + 1], mcs=self.mcs, mcs_mol=self.mcs_mol
            )
            if not m._mappings:
                raise RuntimeError(f"No mapping between {i} and {i+1}")
            mappings.append(m._mappings[0])
        return mappings

    @classmethod
    def _compute_tuple_mapping(cls, left: nx.Graph, right: nx.Graph) -> GraphMapping:
        """Map ``left`` nodes to ``right`` nodes sharing an atom-map number.

        Falls back to an identity mapping over common node ids when no
        atom-map number is shared.
        """
        left_by_map = cls._nodes_by_atom_map(left)
        right_by_map = cls._nodes_by_atom_map(right)
        common_maps = sorted(set(left_by_map) & set(right_by_map))
        mapping = {left_by_map[amap]: right_by_map[amap] for amap in common_maps}
        if mapping:
            return mapping
        common_nodes = sorted(set(left.nodes()) & set(right.nodes()))
        return {node: node for node in common_nodes}

    @classmethod
    def _has_tuple_atom_maps(cls, graphs: List[nx.Graph]) -> bool:
        """Return True if any node in any graph has a usable atom map."""
        return any(
            cls._tuple_node_pid(attrs) is not None
            for graph in graphs
            for _, attrs in graph.nodes(data=True)
        )

    @staticmethod
    def _tuple_node_pid(attrs: Mapping[str, Any]) -> int | None:
        """Return the atom-map number of a node as an int, or None.

        For a paired ``atom_map`` the right-hand value wins unless it is
        empty (``None``, ``0`` or ``""``), in which case the left one is
        used.
        """
        atom_map = attrs.get("atom_map")
        if isinstance(atom_map, tuple) and len(atom_map) == 2:
            atom_map = atom_map[1] if atom_map[1] not in (None, 0, "") else atom_map[0]
        if atom_map in (None, 0, ""):
            return None
        return int(atom_map)

    @staticmethod
    def _nodes_by_atom_map(graph: nx.Graph) -> Dict[int, NodeID]:
        """Index nodes by atom-map number; the first node per number wins."""
        by_map: Dict[int, NodeID] = {}
        for node, attrs in graph.nodes(data=True):
            atom_map = MTG._tuple_node_pid(attrs)
            if atom_map is not None:
                by_map.setdefault(atom_map, node)
        return by_map

    # ------------------------------------------------------------------
    # Read-only accessors
    # ------------------------------------------------------------------
    @property
    def node_mapping(self) -> Dict[Tuple[int, NodeID], int]:
        """Copy of the ``(snapshot_index, node_id) -> pid`` mapping."""
        return dict(self._node_map)

    @property
    def k(self) -> int:
        """Number of fused snapshots."""
        return self._k
# from __future__ import annotations
# """MTG – Mechanistic Transition Graph fusion utility.
# This module exposes :class:`~MTG`, a helper that merges a chronological
# sequence of **Intermediate Transition State** (ITS) graphs – or their RSMI
# string representations – into a single *product* graph capturing the entire
# bond-order history across the reaction trajectory.
# The implementation is self-contained except for the external *synkit* helpers
# used for RSMI⇆ITS inter-conversion and canonicalisation.
# """
# from collections.abc import Iterator
# from typing import Any, Dict, List, Mapping, MutableMapping, Set, Tuple, Union
# import networkx as nx
# # ---------------------------------------------------------------------------
# # Optional dependencies
# # ---------------------------------------------------------------------------
# try:
# import pandas as pd # type: ignore
# except ImportError: # pragma: no cover – pandas is only required for to_dataframe()
# pd = None # noqa: N816 – keep lowercase alias even if stubbed
# from synkit.Graph.Hyrogen._misc import h_to_explicit # noqa: WPS433 – external import
# from synkit.Graph.ITS.normalize_aam import NormalizeAAM # noqa: WPS433
# from synkit.Graph.MTG.mcs_matcher import MCSMatcher # noqa: WPS433
# from synkit.Graph.MTG.utils import (
# normalize_hcount_and_typesGH,
# normalize_order,
# label_mtg_edges,
# compute_standard_order,
# ) # noqa: WPS433
# from synkit.Graph.canon_graph import GraphCanonicaliser # noqa: WPS433
# from synkit.IO import its_to_rsmi, rsmi_to_its # noqa: WPS433
# NodeID = int
# OrderPair = Tuple[float, float]
# MissingOrder = Tuple[Set[float], Set[float]]
# GraphMapping = Dict[NodeID, NodeID]
# # A placeholder for a *missing* edge-order in a particular snapshot. Using
# # `set()` makes the value clearly distinguishable from genuine numeric orders.
# _PLACEHOLDER: MissingOrder = (set(), set())
# __all__ = [
# "MTG",
# ]
# class MTG: # pylint: disable=too-many-instance-attributes
# """Fuse a chronological series of ITS graphs into a Mechanistic Transition Graph.
# :param sequences: Either a list of ITS-format NetworkX graphs or a list of RSMI
# strings in chronological order.
# :type sequences: List[nx.Graph] or List[str]
# :param mappings: Pre-computed node mappings between each consecutive pair of graphs.
# If None, mappings are computed via MCSMatcher.
# :type mappings: List[GraphMapping] or None
# :param node_label_names: Node attribute keys used for MCS-based matching.
# :type node_label_names: List[str] or None
# :param canonicaliser: Optional GraphCanonicaliser to canonicalise each
# snapshot before fusion.
# :type canonicaliser: GraphCanonicaliser or None
# :raises ValueError: If fewer than two sequences are provided or mapping count mismatches.
# :raises TypeError: If sequence elements are neither NetworkX graphs nor RSMI strings.
# :raises RuntimeError: If automatic mapping fails for any adjacent graph pair.
# """
# # ---------------------------------------------------------------------
# # Construction helpers
# # ---------------------------------------------------------------------
# def __init__(
# self,
# sequences: Union[List[nx.Graph], List[str]],
# mappings: List[GraphMapping] | None = None,
# *,
# node_label_names: List[str] | None = None,
# canonicaliser: GraphCanonicaliser | None = None,
# mcs_mol: bool = False,
# mcs: bool = False,
# ) -> None: # noqa: D401 – docstring handled above
# # --- Basic validation ------------------------------------------------
# if len(sequences) < 2: # also covers non-list via __len__ check raising
# raise ValueError(
# "At least two ITS snapshots are required to construct an MTG.",
# )
# self._node_label_names: List[str] = node_label_names or [
# "element",
# "charge",
# "hcount",
# ]
# self._canonicaliser = canonicaliser
# self.mcs_mol: bool = mcs_mol
# self.mcs: bool = mcs
# # --- Input normalisation -------------------------------------------
# self._graphs: List[nx.Graph] = self._prepare_graph_sequence(sequences)
# self._k: int = len(self._graphs)
# # --- Graph-to-graph mappings ---------------------------------------
# self._mappings: List[GraphMapping] = (
# self._compute_mappings(self._graphs) if mappings is None else mappings
# )
# if len(self._mappings) != self._k - 1:
# raise ValueError(
# "Need exactly one mapping per pair of adjacent snapshots.",
# )
# # --- Core fusion machinery -----------------------------------------
# self._prod_nodes: Dict[int, Dict[str, Any]]
# self._node_map: Dict[Tuple[int, NodeID], int]
# self._graph: nx.Graph # final fused graph – populated below
# self._build_node_map_and_attributes()
# self._build_edge_history_and_graph()
# # ---------------------------------------------------------------------
# # Python dunder & public helpers
# # ---------------------------------------------------------------------
# def __repr__(self) -> str: # noqa: D401 – simple representation
# """Return a summary representation including snapshot count and graph size."""
# return (
# f"<MTG k={self._k} nodes={self._graph.number_of_nodes()} "
# f"edges={self._graph.number_of_edges()}>"
# )
# # Collection-like API ---------------------------------------------------
# def __len__(self) -> int:
# """Return the number of fused nodes in the product graph."""
# return self._graph.number_of_nodes()
# def __iter__(self) -> Iterator[int]:
# """Iterate over fused node identifiers."""
# return iter(self._graph.nodes)
# def __getitem__(self, node_id: int) -> Dict[str, Any]:
# """Access the attribute dictionary of a fused node by its ID.
# :param node_id: Fused node identifier
# :type node_id: int
# :returns: Node attribute mapping
# :rtype: Dict[str, Any]
# """
# return self._graph.nodes[node_id]
# # ---------------------------------------------------------------------
# # Public / user-facing API
# # ---------------------------------------------------------------------
# @staticmethod
# def describe() -> str: # noqa: D401 – simple helper
# """Return an inline usage example for quick reference."""
# return (
# "# Example usage\n"
# "mtg = MTG([G0, G1, G2])\n"
# "fused_graph = mtg.get_mtg()\n"
# "rsmi_with_aam = mtg.get_aam()\n"
# )
# # ------------------------------------------------------------------
# # Graph export helpers
# # ------------------------------------------------------------------
# def get_mtg(self, *, directed: bool = False) -> nx.Graph:
# """Return the fused product graph.
# :param directed: If True, return a directed copy of the fused graph
# :type directed: bool
# :returns: Fused product graph
# :rtype: networkx.Graph or networkx.DiGraph
# """
# return self._graph.to_directed() if directed else self._graph
# def get_compose_its(self, *, directed: bool = False) -> nx.Graph:
# """Return a graph with normalized edge orders for ITS export.
# :param directed: If True, normalize a directed version
# :type directed: bool
# :returns: Graph with collapsed (order_G, order_H) tuples
# :rtype: networkx.Graph or networkx.DiGraph
# """
# fused = self.get_mtg(directed=directed)
# fused = label_mtg_edges(fused, inplace=False)
# fused = normalize_order(fused)
# return compute_standard_order(fused)
# def get_aam(
# self,
# *,
# directed: bool = False,
# explicit_h: bool = False,
# ) -> str:
# """Export fused graph to an RSMI string with atom-atom mapping.
# :param directed: If True, use a directed ITS representation
# :type directed: bool
# :param explicit_h: If True, include explicit hydrogens; otherwise normalize AAM
# :type explicit_h: bool
# :returns: RSMI string with AAM
# :rtype: str
# """
# its_graph = self.get_compose_its(directed=directed)
# rsmi = its_to_rsmi(its_graph, explicit_hydrogen=True)
# if not explicit_h:
# rsmi = NormalizeAAM().fit(rsmi, fix_aam_indice=False)
# return rsmi
# def to_dataframe(self):
# """Return a pandas DataFrame of fused node attributes.
# :returns: DataFrame indexed by fused node IDs with attribute columns
# :rtype: pandas.DataFrame
# :raises RuntimeError: If pandas is not installed
# """
# if pd is None:
# raise RuntimeError(
# "pandas is required for `to_dataframe()` but is not installed."
# )
# return pd.DataFrame.from_dict(
# dict(self._graph.nodes(data=True)), orient="index"
# )
# # ------------------------------------------------------------------
# # Node & edge fusion internals
# # ------------------------------------------------------------------
# @staticmethod
# def _merge_attrs(lhs: MutableMapping[str, Any], rhs: Mapping[str, Any]) -> None:
# """Update in-place, preferring non-empty or non-None values from rhs.
# :param lhs: Target attribute dict to update
# :type lhs: MutableMapping[str, Any]
# :param rhs: Source attribute dict
# :type rhs: Mapping[str, Any]
# """
# for key, value in rhs.items():
# if (
# not lhs.get(key) and value is not None
# ): # noqa: WPS501 – explicitly allow False/0
# lhs[key] = value
# # .................................................................
# def _build_node_map_and_attributes(self) -> None:
# """Construct fused nodes by merging snapshots backwards.
# Builds:
# - self._prod_nodes: pid → attribute dict
# - self._node_map: (snapshot_index, original_node_id) → pid
# """
# prod: Dict[int, Dict[str, Any]] = {}
# node_map: Dict[Tuple[int, NodeID], int] = {}
# # --- Seed with last snapshot -------------------------------------
# last_graph = self._graphs[-1]
# for nid, attrs in last_graph.nodes(data=True):
# prod[nid] = attrs.copy()
# node_map[(self._k - 1, nid)] = nid
# next_pid: int = (max(prod) if prod else -1) + 1
# # --- Walk backwards and merge ------------------------------------
# for idx in range(self._k - 2, -1, -1):
# G = self._graphs[idx]
# mapping = self._mappings[idx]
# for nid, attrs in G.nodes(data=True):
# target = mapping.get(nid)
# if target is not None and (idx + 1, target) in node_map:
# pid = node_map[(idx + 1, target)]
# self._merge_attrs(prod[pid], attrs)
# else: # new (unmapped) node – assign fresh pid
# while next_pid in prod: # safeguard although unlikely
# next_pid += 1
# pid = next_pid
# prod[pid] = attrs.copy()
# next_pid += 1
# node_map[(idx, nid)] = pid
# self._prod_nodes = prod
# self._node_map = node_map
# # .................................................................
# def _build_edge_history_and_graph(
# self,
# ) -> None: # noqa: C901 – complex but contained
# """Assemble the fused graph with per-edge order histories.
# Each edge in the result has an `order` attribute: a tuple of
# length `k`, where each element is either an order-pair or a placeholder.
# """
# history: Dict[Tuple[int, int], List[MissingOrder]] = {}
# # Collect order trajectories -----------------------------------------------------
# for gi, G in enumerate(self._graphs):
# for u, v, attrs in G.edges(data=True):
# pu, pv = (
# self._node_map[(gi, u)],
# self._node_map[(gi, v)],
# )
# key = tuple(sorted((pu, pv))) # undirected canonical ordering
# orders = history.setdefault(key, [_PLACEHOLDER] * self._k)
# orders[gi] = attrs.get("order", _PLACEHOLDER) # type: ignore[arg-type]
# # Build fused NetworkX graph -----------------------------------------------------
# graph = nx.Graph()
# graph.add_nodes_from(self._prod_nodes.items())
# for (u, v), orders in history.items():
# graph.add_edge(u, v, order=tuple(orders))
# # Sanity check ------------------------------------------------------------------
# if graph.number_of_nodes() != len(self._prod_nodes):
# raise RuntimeError("Node count mismatch during MTG assembly.")
# self._graph = graph
# # ------------------------------------------------------------------
# # Mapping helpers
# # ------------------------------------------------------------------
# def _prepare_graph_sequence(
# self, seq: List[nx.Graph] | List[str]
# ) -> List[nx.Graph]:
# """Convert input list to a cleaned sequence of ITS graphs.
# :param seq: Raw sequence of graphs or RSMI strings
# :type seq: List[nx.Graph] or List[str]
# :returns: List of normalized ITS graphs
# :rtype: List[nx.Graph]
# :raises TypeError: If an element is neither nx.Graph nor str
# """
# prepared: List[nx.Graph] = []
# for item in seq:
# if isinstance(item, str):
# graph = rsmi_to_its(item, core=False)
# elif isinstance(item, nx.Graph):
# graph = item
# else: # pragma: no cover – guard against future unsupported types
# raise TypeError(
# "Sequences must contain either NetworkX graphs or RSMI strings.",
# )
# # Canonicalise (optional) ---------------------------------------------------
# if self._canonicaliser is not None:
# graph = self._canonicaliser.canonicalise_graph(graph).canonical_graph # type: ignore[attr-defined]
# # Ensure explicit hydrogens & normalised hcount / typesGH ----------
# graph = h_to_explicit(graph, its=True)
# graph = normalize_hcount_and_typesGH(graph)
# prepared.append(graph)
# return prepared
# # ..................................................................
# def _compute_mappings(self, graphs: List[nx.Graph]) -> List[GraphMapping]:
# """Compute atom mappings via MCS matching for each adjacent pair.
# :param graphs: ITS graphs in chronological order
# :type graphs: List[nx.Graph]
# :returns: List of mappings of length k-1
# :rtype: List[GraphMapping]
# :raises RuntimeError: If no mapping found for a pair
# """
# mappings: List[GraphMapping] = []
# for idx in range(len(graphs) - 1):
# matcher = MCSMatcher(node_label_names=self._node_label_names)
# matcher.find_rc_mapping(
# graphs[idx], graphs[idx + 1], mcs=self.mcs, mcs_mol=self.mcs_mol
# )
# if not matcher._mappings: # pylint: disable=protected-access
# raise RuntimeError(
# f"No MCS mapping found between snapshots {idx} and {idx + 1}.",
# )
# mappings.append(matcher._mappings[0]) # pylint: disable=protected-access
# return mappings
# # ------------------------------------------------------------------
# # Convenience accessors (mostly for unit tests)
# # ------------------------------------------------------------------
# @property
# def node_mapping(self) -> Dict[Tuple[int, NodeID], int]:
# """Return the internal mapping from (snapshot_index, original_node_id) to fused pid.
# :returns: Mapping dictionary
# :rtype: Dict[Tuple[int, NodeID], int]
# """
# return dict(self._node_map)
# @property
# def k(self) -> int:
# """Return the number of snapshots fused.
# :returns: Snapshot count
# :rtype: int
# """
# return self._k
# import networkx as nx
# from collections import defaultdict
# from typing import Dict, List, Tuple, Any, Set, Union
# # -----------------------------------------------------------------------------
# # Type aliases
# # -----------------------------------------------------------------------------
# NodeID = int
# Order = Tuple[float, float]
# Node = Tuple[NodeID, Dict[str, Any]]
# Edge = Tuple[NodeID, NodeID, Dict[str, Any]]
# __all__ = ["MTG"]
# class MTG:
# """Fuse two molecular graphs via a pair‑groupoid edge‑composition rule.
# Parameters
# ----------
# G1, G2
# Input :class:`networkx.Graph` (or *DiGraph*) objects. Nodes must carry an
# ``"element"`` attribute; edges carry an ``"order"`` 2‑tuple *(x, y)*.
# mapping
# A partial node map **G1 → G2** indicating which atoms are chemically
# identical (intersection). Keys are node IDs in *G1*, values in *G2*.
# Notes
# -----
# 1. ``intersection_ids`` lists the fused IDs of nodes matched by the mapping ``G1[i] → G2[j]``.
# 2. Edges are inserted in two passes:
# * *Pass 1* – all edges from *G1* are copied unchanged.
# * *Pass 2* – edges from *G2* are remapped; when both endpoints are in
# ``intersection_ids`` **and** their bond orders satisfy the
# *pair-groupoid* condition
# ``(a₁, a₂) + (b₁, b₂) with a₂ == b₁ → (a₁, b₂)``,
# the edges are *composed* instead of duplicated.
# Examples
# --------
# >>> mtg = MTG(G1, G2, {1: 3, 4: 6, 5: 1})
# >>> fused = mtg.get_graph()
# >>> fused.nodes(data=True)
# ...
# """
# # ------------------------------------------------------------------
# # Construction helpers
# # ------------------------------------------------------------------
# def __init__(
# self,
# G1: Union[nx.Graph, nx.DiGraph],
# G2: Union[nx.Graph, nx.DiGraph],
# mapping: Dict[NodeID, NodeID],
# ) -> None:
# # Store originals
# self.G1 = G1
# self.G2 = G2
# self.mapping12 = mapping # G1 → G2
# # ---- 1. Build fused node set ---------------------------------
# (
# self.product_nodes, # list[(id, attrs)] in fused graph
# self.map1, # G1 id → fused id
# self.map2, # G2 id → fused id
# self.intersection_ids, # list[fused id]
# ) = self._fuse_nodes()
# # ---- 2. Fuse edges with groupoid rule ------------------------
# fused_edges_step1 = self._insert_edges_from(self.G1.edges(data=True), self.map1)
# self.product_edges = self._insert_edges_from(
# self.G2.edges(data=True), self.map2, existing=fused_edges_step1
# )
# # ------------------------------------------------------------------
# # Node fusion
# # ------------------------------------------------------------------
# def _fuse_nodes(self):
# merged: Dict[NodeID, Dict[str, Any]] = {}
# map1: Dict[NodeID, NodeID] = {}
# map2: Dict[NodeID, NodeID] = {}
# used: Set[NodeID] = set()
# # --- copy G1 directly into fused graph ------------------------
# for v, attrs in self.G1.nodes(data=True):
# merged[v] = attrs.copy()
# map1[v] = v
# used.add(v)
# # inverse mapping: G2 node → G1 node it merges to
# inv_map = {g2: g1 for g1, g2 in self.mapping12.items()}
# intersection: List[NodeID] = []
# # --- process G2 nodes -----------------------------------------
# next_id = max(used) + 1 if used else 0
# for v, attrs in self.G2.nodes(data=True):
# if v in inv_map: # merged node
# tgt = inv_map[v]
# map2[v] = tgt
# intersection.append(tgt)
# else: # unique node from G2
# while next_id in used:
# next_id += 1
# merged[next_id] = attrs.copy()
# map2[v] = next_id
# used.add(next_id)
# next_id += 1
# nodes_sorted = sorted(merged.items()) # list[(id, attrs)]
# return nodes_sorted, map1, map2, intersection
# # ------------------------------------------------------------------
# # Edge insertion & groupoid composition
# # ------------------------------------------------------------------
# def _insert_edges_from(
# self, edge_iter, node_map: Dict[NodeID, NodeID], existing: List[Edge] = None
# ) -> List[Edge]:
# """Return a new edge list built from a copy of *existing* plus the remapped
# edges, applying the groupoid rule when possible."""
# existing = [] if existing is None else existing.copy()
# # Remap and append new edges
# for u, v, attrs in edge_iter:
# u3 = node_map[u]
# v3 = node_map[v]
# existing.append((u3, v3, attrs.copy()))
# # Canonicalize keys for undirected graphs
# def key(u, v):
# return (u, v) if isinstance(self.G1, nx.DiGraph) else tuple(sorted((u, v)))
# # Group edges by (u,v)
# buckets: Dict[Tuple[NodeID, NodeID], List[Order]] = defaultdict(list)
# bucket_src: Dict[Tuple[NodeID, NodeID], List[str]] = defaultdict(list)
# for idx, (u, v, attrs) in enumerate(existing):
# buckets[key(u, v)].append(tuple(attrs["order"]))
# bucket_src[key(u, v)].append("G1" if idx < len(self.G1.edges) else "G2")
# fused_edges: List[Edge] = []
# for (u, v), orders in buckets.items():
# # src = bucket_src[(u, v)]
# if (
# u in self.intersection_ids
# and v in self.intersection_ids
# and len(orders) >= 2
# ):
# # Attempt pair‑wise composition between G1 (first) and any G2 edge
# made_composite = False
# for idx2, ord2 in enumerate(orders[1:], start=1):
# a1, a2 = orders[0]
# b1, b2 = ord2
# if a2 == b1:
# fused_edges.append((u, v, {"order": (a1, b2)}))
# made_composite = True
# break
# if not made_composite:
# # fall back to *all* distinct orders
# for ord_ in orders:
# fused_edges.append((u, v, {"order": ord_}))
# else:
# for ord_ in orders:
# fused_edges.append((u, v, {"order": ord_}))
# return self._dedupe_edges(fused_edges)
# # ------------------------------------------------------------------
# @staticmethod
# def _dedupe_edges(edges: List[Edge]) -> List[Edge]:
# seen: Set[Tuple[int, int, Order]] = set()
# out: List[Edge] = []
# for u, v, attrs in edges:
# key = (min(u, v), max(u, v), tuple(attrs["order"]))
# if key not in seen:
# seen.add(key)
# out.append((u, v, attrs))
# return out
# # ------------------------------------------------------------------
# # Public helpers
# # ------------------------------------------------------------------
# def get_nodes(self) -> List[Node]:
# """List of `(id, attrs)` for fused graph."""
# return self.product_nodes
# def get_edges(self) -> List[Edge]:
# """List of `(u, v, attrs)` for fused graph."""
# return self.product_edges
# def get_map1(self) -> Dict[NodeID, NodeID]:
# return self.map1
# def get_map2(self) -> Dict[NodeID, NodeID]:
# return self.map2
# def get_graph(self, *, directed: bool = False):
# G = nx.DiGraph() if directed else nx.Graph()
# G.add_nodes_from(self.product_nodes)
# for u, v, attrs in self.product_edges:
# o = attrs["order"]
# attrs["standard_order"] = o[0] - o[1] if None not in o else None
# G.add_edge(u, v, **attrs)
# return G
# # ------------------------------------------------------------------
# def __repr__(self):
# return f"MTG(|V|={len(self.product_nodes)}, |E|={len(self.product_edges)})"