# Source code for synkit.Graph.MTG.mcs_matcher

"""mcs_matcher.py — Maximum/Common Subgraph Matcher
=================================================

A convenience wrapper around ``networkx.algorithms.isomorphism.GraphMatcher``
that finds *all* common-subgraph (or maximum-common-subgraph) node mappings
between two molecular graphs.

Highlights
----------
* **Flexible node matching** via ``generic_node_match``.
* **Scalar edge attribute** comparison (e.g. ``order``).
* Results are **cached** – call :py:meth:`get_mappings` to retrieve them.
* Helpful ``help()`` and ``__repr__`` utilities inspired by the MTG API style.

Public API
~~~~~~~~~~
``MCSMatcher(node_label_names, node_label_defaults, edge_attribute='order', allow_shift=True)``
    Construct a matcher instance.

``matcher.find_common_subgraph(G1, G2, mcs=False, mcs_mol=False)``
    Run the search (stores but does *not* return mappings). If ``mcs_mol=True``,
    find mappings by matching entire connected components (largest molecules).

``matcher.get_mappings()``
    Retrieve the stored mapping list.

``matcher.find_rc_mapping(rc1, rc2, mcs=False, mcs_mol=False)``
    Convenience wrapper for ITS‐reaction‑centre objects (via ``its_decompose``).

Dependencies
~~~~~~~~~~~~
* Python 3.9+
* NetworkX ≥ 3.0
* ``synkit.Graph.ITS.its_decompose`` (optional helper)
"""

from __future__ import annotations

import itertools
from typing import Dict, List, Callable, Optional, Any, Set

import networkx as nx
from networkx.algorithms.isomorphism import GraphMatcher, generic_node_match

try:
    from synkit.Graph.ITS import its_decompose  # optional
except ImportError:  # pragma: no cover – allow standalone use
    its_decompose = None  # type: ignore

__all__ = ["MCSMatcher"]


class MCSMatcher:
    """Common / maximum‑common subgraph matcher.

    Parameters
    ----------
    node_label_names : list[str], optional
        Node attribute keys to compare (default ``["element"]``).
    node_label_defaults : list[Any], optional
        Fallback values when an attribute is missing (default ``["*"]``,
        repeated once per entry of ``node_label_names``).
    edge_attribute : str, optional
        Edge attribute storing the scalar *order* (default ``"order"``).
    allow_shift : bool, optional
        Placeholder for future asymmetric rules (ignored for scalars).
    """

    def __init__(
        self,
        node_label_names: Optional[List[str]] = None,
        node_label_defaults: Optional[List[Any]] = None,
        edge_attribute: str = "order",
        allow_shift: bool = True,
    ) -> None:
        if node_label_names is None:
            node_label_names = ["element"]
        if node_label_defaults is None:
            node_label_defaults = ["*"] * len(node_label_names)

        # One equality operator per compared node attribute.
        self.node_match: Callable[[Dict[str, Any], Dict[str, Any]], bool] = (
            generic_node_match(
                node_label_names,
                node_label_defaults,
                [lambda x, y: x == y] * len(node_label_names),
            )
        )
        self.edge_attr = edge_attribute
        self.allow_shift = allow_shift  # stored only; not read by this class

        # internal cache, refreshed by every ``find_*`` call
        self._mappings: List[Dict[int, int]] = []
        self._last_size: int = 0

    def _edge_match(
        self, host_attrs: Dict[str, Any], pat_attrs: Dict[str, Any]
    ) -> bool:
        """Compare scalar *order* attributes (exact equality).

        Both values are coerced with ``float``. An edge whose attribute is
        missing or not convertible to ``float`` never matches.
        """
        try:
            return float(host_attrs.get(self.edge_attr)) == float(
                pat_attrs.get(self.edge_attr)
            )
        except (TypeError, ValueError):
            return False

    @staticmethod
    def _invert_mapping(gm_mapping: Dict[int, int]) -> Dict[int, int]:
        """Convert *host→pattern* dict to *pattern→host*."""
        return {pat: host for host, pat in gm_mapping.items()}

    def _find_mcs_mol(self, G1: nx.Graph, G2: nx.Graph) -> Dict[int, int]:
        """Match whole connected components of ``G1`` onto ``G2``.

        Components are visited largest first. Each ``G1`` component is paired
        with the first not-yet-used ``G2`` component that has the same number
        of nodes and is isomorphic to it under the node/edge matchers. The
        per-component mappings (``G1`` node → ``G2`` node) are merged into a
        single dict. Components without a partner contribute nothing.
        """
        # sort components by size descending
        comps1 = sorted(nx.connected_components(G1), key=len, reverse=True)
        comps2 = sorted(nx.connected_components(G2), key=len, reverse=True)

        used2: Set[frozenset[int]] = set()
        combined: Dict[int, int] = {}

        for comp1 in comps1:
            size = len(comp1)
            sub1 = G1.subgraph(comp1)
            for comp2 in comps2:
                if len(comp2) != size:
                    continue
                key2 = frozenset(comp2)
                if key2 in used2:
                    continue
                gm = GraphMatcher(
                    sub1,
                    G2.subgraph(comp2),
                    node_match=self.node_match,
                    edge_match=self._edge_match,
                )
                if gm.is_isomorphic():
                    # ``gm.mapping`` holds the sub1 → sub2 node mapping of
                    # the isomorphism that was just found.
                    combined.update(gm.mapping)
                    used2.add(key2)
                    break
        return combined

    def find_common_subgraph(
        self,
        G1: nx.Graph,
        G2: nx.Graph,
        *,
        mcs: bool = False,
        mcs_mol: bool = False,
    ) -> None:
        """Search for subgraph isomorphisms and cache the mappings.

        Every ``k``-node subset of ``G1`` (from the largest feasible ``k``
        down to 1) is taken as an induced subgraph and searched for in
        ``G2``. Each distinct *pattern→host* mapping is stored once. Results
        are retrieved with :py:meth:`get_mappings`; nothing is returned.

        Parameters
        ----------
        G1 : nx.Graph
            *pattern* graph (searched as a subgraph)
        G2 : nx.Graph
            *host* graph
        mcs : bool, optional
            If *True*, keep only mappings of maximum size and stop as soon
            as the first (largest) size with a match has been processed.
        mcs_mol : bool, optional
            If *True*, match entire connected components instead; the cache
            then holds exactly one combined mapping (possibly empty) and
            ``mcs`` is ignored.
        """
        self._mappings.clear()
        self._last_size = 0

        if mcs_mol:
            combined = self._find_mcs_mol(G1, G2)
            self._mappings = [combined]
            self._last_size = len(combined)
            return

        max_k = min(len(G1), len(G2))
        seen: Set[tuple] = set()

        for k in range(max_k, 0, -1):
            level_found = False
            for nodes in itertools.combinations(G1.nodes(), k):
                subG = G1.subgraph(nodes).copy()
                gm = GraphMatcher(
                    G2,
                    subG,
                    node_match=self.node_match,
                    edge_match=self._edge_match,
                )
                for iso in gm.subgraph_isomorphisms_iter():
                    inv = self._invert_mapping(iso)
                    key = tuple(sorted(inv.items()))
                    if key in seen:
                        continue
                    seen.add(key)
                    self._mappings.append(inv)
                    level_found = True

            if level_found:
                # Sizes are visited in descending order, so the first level
                # with a hit is the maximum. Record it once instead of
                # overwriting it with every smaller level.
                if not self._last_size:
                    self._last_size = k
                if mcs:
                    # Every mapping gathered so far has exactly k entries,
                    # so no size filtering is needed afterwards.
                    break

        # final ordering – largest first then lexicographic
        self._mappings.sort(key=lambda d: (-len(d), tuple(sorted(d.items()))))

    def find_rc_mapping(
        self,
        rc1,
        rc2,
        *,
        mcs: bool = False,
        mcs_mol: bool = False,
    ) -> None:
        """Run :py:meth:`find_common_subgraph` on decomposed reaction centres.

        ``its_decompose`` is applied to both arguments. The second graph it
        returns for ``rc1`` is used as the pattern and the first graph it
        returns for ``rc2`` as the host.
        NOTE(review): presumably these are the product side of ``rc1`` and
        the reactant side of ``rc2`` — confirm against ``its_decompose``.

        Parameters
        ----------
        rc1, rc2
            Objects accepted by ``its_decompose``.
        mcs, mcs_mol : bool, optional
            Forwarded unchanged to :py:meth:`find_common_subgraph`.

        Raises
        ------
        ImportError
            If ``its_decompose`` could not be imported.
        """
        if its_decompose is None:
            raise ImportError(
                "synkit is not available; cannot decompose reaction centres."
            )
        _, r1 = its_decompose(rc1)
        l2, _ = its_decompose(rc2)
        self.find_common_subgraph(r1, l2, mcs=mcs, mcs_mol=mcs_mol)

    def get_mappings(self) -> List[Dict[int, int]]:
        """Return the cached mapping list (empty if `find_*` not yet called).

        Both the list and each mapping dict are copies, so callers may
        modify the result without corrupting the cache.
        """
        return [dict(mapping) for mapping in self._mappings]

    @property
    def last_size(self) -> int:
        """Node count of the largest mapping from the last search (0 if none)."""
        return self._last_size

    def __repr__(self) -> str:  # noqa: D401
        return (
            f"MCSMatcher(mappings={len(self._mappings)}, last_size={self._last_size})"
        )

    def help(self) -> None:  # noqa: D401
        """Print the class docstring followed by every public attribute name."""
        print(self.__doc__)
        for name in dir(self):
            if not name.startswith("_"):
                print(name)