Source code for synkit.Graph.Matcher.subgraph_matcher

from __future__ import annotations

"""
subgraph_matcher.py
================================
A **lean**, **typed**, and **high-performance** successor to the original
sub-graph matching utilities in SynKit.

Key Features
------------
• **Speed**
  • Element-multiset, node-attribute, degree-histogram, and WL-1 hashing
    pre-filters remove up to 95 % of impossible host-CCs before VF2.
  • Heuristic CC ordering and optional result limits prune the search tree.

• **Flexibility**
  • Three matching strategies:
    – ALL: classic VF2 over the entire host
    – COMPONENT: CC-aware, distinct-CC enforcement
    – BACKTRACK: component-aware with classic-fallback
  • Fallback to brute-force VF2 when host has fewer CCs than pattern
  • Optional `strict_cc_count` to enforce exact CC counts

• **Safety & Cleanliness**
  • Never mutates your input graphs
  • Validates inputs and raises clear errors on misuse
  • Full type annotations throughout
  • Single public entry point:
      `SubgraphSearchEngine.find_subgraph_mappings(...)`

Public API
----------
SubgraphSearchEngine.find_subgraph_mappings(
    host: nx.Graph,
    pattern: nx.Graph,
    node_attrs: List[str],
    edge_attrs: List[str],
    strategy: Strategy = Strategy.COMPONENT,
    *,
    max_results: Optional[int] = None,
    strict_cc_count: bool = False,
    wl1_filter: bool = True,
) -> List[MappingDict]

    Dispatches to one of:
      - `_all_monomorphisms` (classic VF2)
      - `_component_aware_mappings` (fast, CC-aware)
      - `_bt_subgraph_mappings` (with fallback)

Helper Functions
----------------
- `wl1_hash(graph, node_attrs)`
  Computes a single-pass Weisfeiler–Lehman coloring signature.
- `_all_monomorphisms(host, pattern, node_attrs, edge_attrs)`
  Fast wrapper around NetworkX’s VF2 that returns every subgraph monomorphism.
- `_component_aware_mappings(...)`
  Splits graphs into connected components (CCs), applies multi-level filters
  (element, attribute, degree, WL-1), then assembles only those mappings
  placing each pattern-CC into a distinct host-CC.
- `_bt_subgraph_mappings(...)`
  Same CC-aware logic but falls back to classic VF2 if any CC can’t embed.

Usage Example
-------------
```python
from subgraph_matcher import SubgraphSearchEngine, Strategy

mappings = SubgraphSearchEngine.find_subgraph_mappings(
    host_graph,
    pattern_graph,
    node_attrs=["element", "aromatic"],
    edge_attrs=["order"],
    strategy=Strategy.COMPONENT,
    max_results=50,
    strict_cc_count=False,
)
"""

from typing import Any, Dict, List, Set, Optional, Sequence, Tuple, Callable, Union
from operator import eq
import networkx as nx

from networkx.algorithms.isomorphism import GraphMatcher
from networkx.algorithms.isomorphism import generic_node_match, generic_edge_match

from synkit.Synthesis.Reactor.strategy import Strategy

try:
    from mod import ruleGMLString

    _RULE_AVAILABLE = True
except ImportError:
    ruleGMLString = None  # type: ignore[assignment]
    _RULE_AVAILABLE = False

# ---------------------------------------------------------------------------
# Type aliases
# ---------------------------------------------------------------------------
EdgeAttr = Dict[str, Any]
MappingDict = Dict[int, int]

__all__: Sequence[str] = [
    "SubgraphMatch",
    "SubgraphSearchEngine",
]


def electron_aware_node_match(
    host_data: EdgeAttr,
    pattern_data: EdgeAttr,
    node_attrs: Sequence[str],
) -> bool:
    """Compare node attributes with chemistry-aware cardinality semantics.

    Attributes in ``node_attrs`` are exact matches except:

    - ``hcount``: host must be greater than or equal to pattern
    - ``lone_pairs``: host must be greater than or equal to pattern
    - ``aromatic_n_pi_count``: exact aromatic-N role label when present
    ``radical`` therefore remains exact whenever the caller includes it in
    ``node_attrs``.
    """
    for attr in node_attrs:
        host_value = host_data.get(
            attr, 0 if attr in {"hcount", "lone_pairs"} else None
        )
        pattern_value = pattern_data.get(
            attr, 0 if attr in {"hcount", "lone_pairs"} else None
        )
        if attr in {"hcount", "lone_pairs"}:
            if host_value < pattern_value:
                return False
            continue
        if host_value != pattern_value:
            return False
    return True


def electron_aware_edge_match(
    host_data: EdgeAttr,
    pattern_data: EdgeAttr,
    edge_attrs: Sequence[str],
) -> bool:
    """Compare edge attrs while treating aromatic Kekule phase as non-semantic.

    Aromatic presentation bonds are matched by ``order == 1.5``. Their
    particular ``sigma_order`` / ``pi_order`` split depends on the chosen
    Kekule form and is not stable across independently parsed graphs.
    """
    host_is_aromatic = host_data.get("order") == 1.5
    pattern_is_aromatic = pattern_data.get("order") == 1.5
    for attr in edge_attrs:
        if (
            attr in {"sigma_order", "pi_order"}
            and host_is_aromatic
            and pattern_is_aromatic
        ):
            continue
        if host_data.get(attr) != pattern_data.get(attr):
            return False
    return True


def explain_node_mismatch(
    host_data: EdgeAttr,
    pattern_data: EdgeAttr,
    node_attrs: Sequence[str],
) -> list[str]:
    """Return node-level mismatch reasons using matcher semantics."""
    reasons: list[str] = []
    for attr in node_attrs:
        host_value = host_data.get(
            attr, 0 if attr in {"hcount", "lone_pairs"} else None
        )
        pattern_value = pattern_data.get(
            attr, 0 if attr in {"hcount", "lone_pairs"} else None
        )
        if attr in {"hcount", "lone_pairs"}:
            if host_value < pattern_value:
                reasons.append(f"{attr}: host {host_value} < pattern {pattern_value}")
            continue
        if host_value != pattern_value:
            reasons.append(f"{attr}: host {host_value!r} != pattern {pattern_value!r}")
    return reasons


def resolve_template_match_attrs(
    pattern: nx.Graph,
    *,
    legacy_node_attrs: Sequence[str] = ("element", "charge"),
    legacy_edge_attrs: Sequence[str] = ("order",),
) -> tuple[list[str], list[str]]:
    """Choose match attrs from what the template actually carries.

    Legacy templates keep the legacy attribute set. Electron-aware templates opt
    into extra constraints only when those attrs are present on the template.
    """
    node_attrs = list(legacy_node_attrs)
    edge_attrs = list(legacy_edge_attrs)

    for attr in (
        "aromatic",
        "hcount",
        "lone_pairs",
        "radical",
        "aromatic_n_pi_count",
    ):
        if any(attr in data for _, data in pattern.nodes(data=True)):
            node_attrs.append(attr)

    for attr in ("sigma_order", "pi_order"):
        if any(attr in data for _, _, data in pattern.edges(data=True)):
            edge_attrs.append(attr)

    return node_attrs, edge_attrs


def diagnose_candidate_node_match(
    host_data: EdgeAttr,
    pattern_data: EdgeAttr,
    node_attrs: Sequence[str],
) -> dict[str, Any]:
    """Return a compact node-match diagnostic payload."""
    reasons = explain_node_mismatch(host_data, pattern_data, node_attrs)
    return {"matched": not reasons, "reasons": reasons}


# ---------------------------------------------------------------------------
# Core engine class
# ---------------------------------------------------------------------------
[docs] class SubgraphMatch: """Boolean-only checks for graph isomorphism and subgraph (induced or monomorphic) matching. Provides static methods for NetworkX-based checks and optional GML "rule" backend. """ @staticmethod def _get_edge_labels(graph: Any) -> list: """Extracts the bond types (edge labels) from a given graph. Parameters: - graph: The graph object containing the edges. Returns: - list: List of edge labels as strings. """ return [str(e.bondType) for e in graph.edges] @staticmethod def _get_node_labels(graph: Any) -> list: """Extracts the atom IDs (node labels) from a given graph. Parameters: - graph: The graph object containing the vertices. Returns: - list: List of node labels as strings. """ return [str(v.atomId) for v in graph.vertices]
[docs] @staticmethod def rule_subgraph_morphism( rule_1: str, rule_2: str, use_filter: bool = False ) -> bool: """Evaluates if two GML-formatted rule representations are isomorphic or one is a subgraph of the other (monomorphic). Parameters: - rule_1 (str): GML string of the first rule. - rule_2 (str): GML string of the second rule. - use_filter (bool, optional): Whether to filter by node/edge labels and vertex counts. Returns: - bool: True if the monomorphism condition is met, False otherwise. """ try: rule_obj_1 = ruleGMLString(rule_1, add=False) rule_obj_2 = ruleGMLString(rule_2, add=False) except Exception as e: raise Exception(f"Error parsing GML strings: {e}") if use_filter: if rule_obj_1.context.numVertices > rule_obj_2.context.numVertices: return False node_1_left = SubgraphMatch._get_node_labels(rule_obj_1.left) node_2_left = SubgraphMatch._get_node_labels(rule_obj_2.left) edge_1_left = SubgraphMatch._get_edge_labels(rule_obj_1.left) edge_2_left = SubgraphMatch._get_edge_labels(rule_obj_2.left) if not all(node in node_2_left for node in node_1_left): return False if not all(edge in edge_2_left for edge in edge_1_left): return False return rule_obj_1.monomorphism(rule_obj_2) == 1
[docs] @staticmethod def subgraph_isomorphism( child_graph: nx.Graph, parent_graph: nx.Graph, node_label_names: List[str] = ["element", "charge"], node_label_default: List[Any] = ["*", 0], edge_attribute: str = "order", use_filter: bool = False, check_type: str = "induced", # 'induced' or 'monomorphism' node_comparator: Optional[Callable[[Any, Any], bool]] = None, edge_comparator: Optional[Callable[[Any, Any], bool]] = None, ) -> bool: """Enhanced checks if the child graph is a subgraph isomorphic to the parent graph based on customizable node and edge attributes.""" if use_filter: if ( child_graph.number_of_nodes() > parent_graph.number_of_nodes() or child_graph.number_of_edges() > parent_graph.number_of_edges() ): return False for _, child_data in child_graph.nodes(data=True): found_match = False for _, parent_data in parent_graph.nodes(data=True): match = True for label, default in zip(node_label_names, node_label_default): if child_data.get(label, default) != parent_data.get( label, default ): match = False break if match: found_match = True break if not found_match: return False if edge_attribute: for u, v, child_data in child_graph.edges(data=True): if not parent_graph.has_edge(u, v): return False parent_data = parent_graph[u][v] child_order = child_data.get(edge_attribute) parent_order = parent_data.get(edge_attribute) if isinstance(child_order, tuple) and isinstance( parent_order, tuple ): if child_order != parent_order: return False elif child_order != parent_order: return False node_comparator = node_comparator or eq edge_comparator = edge_comparator or eq node_match = generic_node_match( node_label_names, node_label_default, [node_comparator] * len(node_label_names), ) edge_match = generic_edge_match(edge_attribute, None, edge_comparator) matcher = GraphMatcher( parent_graph, child_graph, node_match=node_match, edge_match=edge_match ) if check_type == "induced": return matcher.subgraph_is_isomorphic() else: return matcher.subgraph_is_monomorphic()
[docs] @staticmethod def is_subgraph( pattern: Union[nx.Graph, str], host: Union[nx.Graph, str], node_label_names: List[str] = ["element", "charge"], node_label_default: List[Any] = ["*", 0], edge_attribute: str = "order", use_filter: bool = False, check_type: str = "induced", backend: str = "nx", ) -> bool: """Unified API for subgraph/isomorphism either via NX or GML backend.""" if backend == "nx": return SubgraphMatch.subgraph_isomorphism( pattern, host, node_label_names, node_label_default, edge_attribute, use_filter, check_type, ) if backend == "mod": if not _RULE_AVAILABLE: raise ImportError("GML rule backend not installed – pip install mod.") return SubgraphMatch.rule_subgraph_morphism( pattern, host, use_filter=use_filter ) raise ValueError(f"Unknown backend: {backend}")
# ----------------------------------------------------------------------------- # Sub‑graph search engine # -----------------------------------------------------------------------------
[docs] class SubgraphSearchEngine: """Static helper routines for sub-graph monomorphism search. :cvar DEFAULT_THRESHOLD: default cap on embedding enumeration (5000) """ DEFAULT_THRESHOLD: int = 5_000 @staticmethod def _quick_pre_filter( host: nx.Graph, pattern: nx.Graph, node_attrs: List[str], threshold: int, ) -> bool: """Estimate if candidate-product exceeds threshold with degree pruning. We refine the basic Cartesian-product by requiring each host candidate to match node attributes *and* have degree ≥ the pattern node’s degree. This tighter filter greatly reduces false positives (over-pruning). """ estimate = 1 # Pre-compute pattern degrees pat_degrees = {n: pattern.degree(n) for n in pattern.nodes()} for p_node, pat_data in pattern.nodes(data=True): pat_deg = pat_degrees[p_node] # count host nodes matching attributes and degree count = sum( 1 for _, host_data in host.nodes(data=True) if electron_aware_node_match(host_data, pat_data, node_attrs) and host.degree(_) >= pat_deg ) # if no candidates; impossible match if count == 0: return True estimate *= count if estimate > threshold * 1e4: # reduce false positives return True return False
[docs] @staticmethod def find_subgraph_mappings( host: nx.Graph, pattern: nx.Graph, *, node_attrs: List[str], edge_attrs: List[str], strategy: Union[str, Strategy] = Strategy.COMPONENT, max_results: Optional[int] = None, strict_cc_count: bool = True, threshold: Optional[int] = None, pre_filter: bool = False, ) -> List[MappingDict]: """Dispatch to a subgraph-matching strategy with optional guards. Parameters ---------- host, pattern NetworkX graphs (host ≥ pattern). node_attrs, edge_attrs Keys of attributes to match; ``hcount`` and ``lone_pairs`` use host-greater-or-equal semantics, while the rest are exact. strategy Matching strategy code or enum ("all", "comp", "bt"). max_results Stop after this many embeddings (None = no limit). strict_cc_count If True, host CC count must ≤ pattern CC count for COMPONENT/BACKTRACK. threshold Override the default cap (DEFAULT_THRESHOLD) on embeddings. pre_filter If True, run a cheap Cartesian-product pre-filter against the threshold. Returns ------- List of dictionaries mapping pattern node→host node. Empty if none or if any guard (pre-filter or enumeration) exceeds the threshold. """ strat = Strategy.from_string(strategy) if strat is Strategy.PARTIAL: raise NotImplementedError("PARTIAL strategy not implemented yet.") # determine effective threshold thresh = ( threshold if threshold is not None else SubgraphSearchEngine.DEFAULT_THRESHOLD ) # defensive copies host = host.copy() pattern = pattern.copy() # quick pre-filter if pre_filter and SubgraphSearchEngine._quick_pre_filter( host, pattern, node_attrs, thresh ): return [] # dispatch if strat is Strategy.ALL: results = SubgraphSearchEngine._find_all_subgraph_mappings( host, pattern, node_attrs, edge_attrs, max_results, thresh ) elif strat is Strategy.COMPONENT: results = SubgraphSearchEngine._find_component_aware_subgraph_mappings( host, pattern, node_attrs, edge_attrs, max_results, strict_cc_count, thresh, ) else: # BACKTRACK results = SubgraphSearchEngine._find_bt_subgraph_mappings( host, pattern, node_attrs, edge_attrs, max_results, strict_cc_count, thresh, ) # final threshold guard return [] if len(results) > thresh else results
@staticmethod def _find_all_subgraph_mappings( host: nx.Graph, pattern: nx.Graph, node_attrs: List[str], edge_attrs: List[str], max_results: Optional[int], threshold: int, ) -> List[MappingDict]: """Classic VF2 over the whole host graph.""" def node_match(nh: EdgeAttr, np: EdgeAttr) -> bool: return electron_aware_node_match(nh, np, node_attrs) def edge_match(eh: EdgeAttr, ep: EdgeAttr) -> bool: return electron_aware_edge_match(eh, ep, edge_attrs) gm = GraphMatcher(host, pattern, node_match=node_match, edge_match=edge_match) results: List[MappingDict] = [] for iso in gm.subgraph_monomorphisms_iter(): results.append({p: h for h, p in iso.items()}) if max_results and len(results) >= max_results: break if len(results) > threshold: return [] return results @staticmethod def _find_component_aware_subgraph_mappings( host: nx.Graph, pattern: nx.Graph, node_attrs: List[str], edge_attrs: List[str], max_results: Optional[int], strict_cc_count: bool, threshold: int, ) -> List[MappingDict]: """Component-aware VF2 split by connected components.""" host_ccs = [host.subgraph(c).copy() for c in nx.connected_components(host)] pat_ccs = [pattern.subgraph(c).copy() for c in nx.connected_components(pattern)] hcc, pcc = len(host_ccs), len(pat_ccs) if pcc == 0: return [{}] if hcc < pcc: return SubgraphSearchEngine._find_all_subgraph_mappings( host, pattern, node_attrs, edge_attrs, max_results, threshold ) if hcc > pcc and strict_cc_count: return [] def node_match(nh: EdgeAttr, np: EdgeAttr) -> bool: return electron_aware_node_match(nh, np, node_attrs) def edge_match(eh: EdgeAttr, ep: EdgeAttr) -> bool: return electron_aware_edge_match(eh, ep, edge_attrs) per_cc: List[List[Tuple[int, MappingDict]]] = [] for pc in pat_ccs: sz = pc.number_of_nodes() cand = [i for i, hc in enumerate(host_ccs) if hc.number_of_nodes() >= sz] if not cand: return [] maps: List[Tuple[int, MappingDict]] = [] for i in cand: gm = GraphMatcher( host_ccs[i], pc, node_match=node_match, edge_match=edge_match ) for iso in gm.subgraph_monomorphisms_iter(): maps.append((i, {p: h for h, p in iso.items()})) if max_results and len(maps) >= max_results: break if len(maps) > threshold: return [] if max_results and len(maps) >= max_results: break if not maps: return [] per_cc.append(maps) order = sorted(range(pcc), key=lambda i: len(per_cc[i])) ordered = [per_cc[i] for i in order] results: List[MappingDict] = [] used: Set[int] = set() def backtrack(level: int, acc: MappingDict): if max_results and len(results) >= max_results: return if len(results) > threshold: return if level == pcc: results.append(acc.copy()) return for hi, m in ordered[level]: if hi in used or any(p in acc for p in m): continue used.add(hi) acc.update(m) backtrack(level + 1, acc) for p in m: acc.pop(p) used.remove(hi) if max_results and len(results) >= max_results: return if len(results) > threshold: return backtrack(0, {}) return results @staticmethod def _find_bt_subgraph_mappings( host: nx.Graph, pattern: nx.Graph, node_attrs: List[str], edge_attrs: List[str], max_results: Optional[int], strict_cc_count: bool, threshold: int, ) -> List[MappingDict]: primary = SubgraphSearchEngine._find_component_aware_subgraph_mappings( host, pattern, node_attrs, edge_attrs, max_results, strict_cc_count, threshold, ) if primary: return primary return SubgraphSearchEngine._find_all_subgraph_mappings( host, pattern, node_attrs, edge_attrs, max_results, threshold ) def __repr__(self) -> str: return "<SubgraphSearchEngine – use `find_subgraph_mappings`>" __str__ = __repr__ # helpful alias for interactive users -------------------------------- @property def help(self) -> str: # noqa: D401 – property for convenience """Return the full module docstring.""" return __doc__