# Source code for synkit.Graph.Matcher.auto_est

"""
auto_est.py
~~~~~~~~~~~

Approximate node automorphism groups (orbits) via 1-WL color refinement,
plus orbit- and component-aware deduplication utilities.

Design goals (SynKit style)
---------------------------
- OOP with a scikit-like ``fit() -> self``.
- Deterministic output ordering.
- Sphinx-style docstrings.
- Helper methods and useful properties.
- Optional "components style" grouping, analogous to exact Automorphism:
  you can obtain *orbit-components* induced by a subset of nodes (anchors).

Important note
--------------
WL-1 provides an *approximate* orbit partition: distinct WL colors imply
distinct orbits, but equal WL colors do not guarantee true symmetry.

This module offers:
- ``AutoEst.orbits``: WL-equivalence classes on the given graph
- ``AutoEst.components(nodes)``: connected components on an induced subgraph
- ``AutoEst.orbit_components(nodes)``: components of the orbit-quotient graph
- ``AutoEst.deduplicate_host_orbits(mappings)``: host-orbit based pruning
- ``AutoEst.deduplicate_pattern_orbits(mappings, pattern_orbits, ...)``:
  pattern-orbit based pruning (anchor-aware)

The anchor-aware deduplication obeys one constraint: nodes belonging to an
anchor component must never be pruned on the basis of orbit-independence.
"""

from __future__ import annotations

from dataclasses import dataclass
from itertools import groupby
from typing import (
    Any,
    Dict,
    FrozenSet,
    Hashable,
    Iterable,
    List,
    Mapping,
    Optional,
    Tuple,
)

import networkx as nx


# --------------------------------------------------------------------------- #
# Small, typed config container (keeps complexity low)
# --------------------------------------------------------------------------- #
@dataclass(frozen=True)
class _WLConfig:
    """
    Immutable bundle of the settings that drive WL refinement.

    Instances are frozen: fields cannot be reassigned after construction.

    :param node_attrs: Keys of the node attributes that seed the initial
        node labels.
    :type node_attrs: tuple[str, ...]
    :param edge_attrs: Keys of the edge attributes that are folded into each
        neighbor signature.
    :type edge_attrs: tuple[str, ...]
    :param max_iter: Upper bound on the number of refinement sweeps.
    :type max_iter: int
    """

    # Node attribute keys read when building initial labels.
    node_attrs: Tuple[str, ...]
    # Edge attribute keys read when building neighbor signatures.
    edge_attrs: Tuple[str, ...]
    # Maximum number of refinement sweeps.
    max_iter: int


# --------------------------------------------------------------------------- #
# Core estimator
# --------------------------------------------------------------------------- #
class AutoEst:
    """
    Approximate node automorphism groups (orbits) via 1-WL color refinement.

    This class performs a Weisfeiler–Lehman (WL-1) style color refinement on
    the input graph to approximate a partition of nodes into
    *automorphism-indistinguishability* classes (often called "WL-orbits").
    In many practical graphs (especially with chemically meaningful node/edge
    labels), the WL partition coincides with, or closely approximates, the
    true orbit partition and is much cheaper than enumerating automorphisms.

    Besides the basic orbit partition, :class:`AutoEst` provides a
    "components-style" interface analogous to the exact automorphism helper:

    - :meth:`components` returns connected components of an induced subgraph
      (useful for "anchor components").
    - :meth:`orbit_components` returns connected components in the
      **orbit-quotient graph** restricted to a node subset, capturing which
      orbits are *coupled* by edges inside the subset.
    - The private ``_dedup_*`` helpers implement signature-based pruning of
      mappings by **pattern** orbits, with an anchor-aware variant that does
      no pruning inside anchored nodes; ``_validate_host_nodes_in_mappings``
      checks host-side nodes against the fitted graph.

    :param graph: Input NetworkX graph. It is not modified in-place.
    :type graph: nx.Graph
    :param node_attrs: Node attribute keys whose values will be included in
        the initial coloring. If ``None``, ``("element", "charge")`` is used.
    :type node_attrs: list[str] or None
    :param edge_attrs: Edge attribute keys whose values will be incorporated
        into the neighborhood signatures. If ``None``, ``("order",)`` is used.
    :type edge_attrs: list[str] or None
    :param max_iter: Maximum number of WL refinement iterations.
    :type max_iter: int

    .. note::
        This is an **approximate** estimator of automorphism orbits: two
        nodes with different final WL colors cannot be in the same orbit,
        but nodes with the same WL color might still be distinguishable by
        higher-order invariants (e.g. higher-dimensional WL, spectral
        invariants, or full automorphism search). In many molecular graphs
        where node/edge labels are informative, this partition is typically
        very close to the true automorphism partition and is often
        sufficient for symmetry-aware pruning.

    .. note::
        "Anchor components" can invalidate orbit-wise independence: when a
        subset of nodes is treated as anchored, orbits that are connected
        through the anchored subgraph should be considered *coupled*. Use
        :meth:`orbit_components` and the anchor-aware pattern-orbit
        deduplication to avoid incorrect pruning.

    .. seealso::
        For discussions relating Weisfeiler–Lehman refinement to
        automorphism indistinguishability and orbit structure, see:

        * A. Dawar and G. Vagnozzi, *Generalizations of k-dimensional
          Weisfeiler–Leman stabilization*, arXiv preprint (2019/2020).

    Example
    -------
    .. code-block:: python

        import networkx as nx
        from synkit.Graph.Matcher.auto_est import AutoEst

        # Simple 4-cycle where all nodes are symmetric under
        # rotation/reflection
        G = nx.cycle_graph(4)
        est = AutoEst(G, node_attrs=[], edge_attrs=[])
        est = est.fit()
        print(est.orbits)    # [frozenset({0, 1, 2, 3})]
        print(est.n_orbits)  # 1

        # "components style": connected components of an induced subgraph
        comps = est.components(nodes=[0, 1])
        print(comps)  # [frozenset({0, 1})]

        # orbit-quotient components of an induced subgraph
        oc = est.orbit_components(nodes=[0, 1, 2, 3])
        print(oc)  # [frozenset({0})]  # one orbit-id component here
    """

    _DEF_NODE_ATTRS: Tuple[str, ...] = ("element", "charge")
    _DEF_EDGE_ATTRS: Tuple[str, ...] = ("order",)

    def __init__(
        self,
        graph: nx.Graph,
        node_attrs: Optional[List[str]] = None,
        edge_attrs: Optional[List[str]] = None,
        max_iter: int = 10,
    ) -> None:
        self._graph: nx.Graph = graph
        cfg = _WLConfig(
            node_attrs=(
                tuple(node_attrs) if node_attrs is not None else self._DEF_NODE_ATTRS
            ),
            edge_attrs=(
                tuple(edge_attrs) if edge_attrs is not None else self._DEF_EDGE_ATTRS
            ),
            max_iter=int(max_iter),
        )
        self._cfg: _WLConfig = cfg
        self._colors: Dict[Hashable, int] = {}
        self._orbits: List[FrozenSet[Hashable]] = []
        # Lazily built by the ``orbit_index`` property; reset on every fit().
        self._orbit_index: Optional[Dict[Hashable, int]] = None
        self._fitted: bool = False

    # ------------------------------------------------------------------ #
    # Public API: properties
    # ------------------------------------------------------------------ #
    @property
    def graph(self) -> nx.Graph:
        """
        Underlying graph.

        :returns: Graph passed to the constructor.
        :rtype: nx.Graph
        """
        return self._graph

    @property
    def node_attrs(self) -> Tuple[str, ...]:
        """
        Node attribute keys used in WL initialization.

        :returns: Node-attribute keys.
        :rtype: tuple[str, ...]
        """
        return self._cfg.node_attrs

    @property
    def edge_attrs(self) -> Tuple[str, ...]:
        """
        Edge attribute keys used in WL refinement.

        :returns: Edge-attribute keys.
        :rtype: tuple[str, ...]
        """
        return self._cfg.edge_attrs

    @property
    def max_iter(self) -> int:
        """
        Maximum number of WL refinement iterations.

        :returns: Maximum refinement iterations.
        :rtype: int
        """
        return self._cfg.max_iter

    @property
    def anchor_component(self) -> FrozenSet[Hashable]:
        """
        Largest connected component of the fitted graph.

        This is a convenience "components-style" accessor. It is commonly
        used as an anchor set for match pruning and symmetry breaking.

        :returns: The node-set of the largest connected component. If
            multiple components share the maximum size, the one with the
            smallest node is returned for determinism. An empty frozenset is
            returned for an empty graph.
        :rtype: frozenset[hashable]
        :raises RuntimeError: If :meth:`fit` has not been called.
        """
        self._ensure_fitted()
        comps = self._components_on_induced(nodes=None)
        if not comps:
            return frozenset()
        # Deterministic tie-break: size descending, then smallest node.
        # ``min`` returns the first minimal element, matching a stable sort.
        return min(comps, key=lambda c: (-len(c), min(c) if c else 0))

    # ------------------------------------------------------------------ #
    # Fitting and results
    # ------------------------------------------------------------------ #
    def fit(self) -> AutoEst:
        """
        Run WL-1 refinement and compute approximate orbits.

        :returns: The fitted estimator (``self``).
        :rtype: AutoEst
        """
        self._initialize_colors()
        self._refine_colors()
        self._build_orbits()
        # Invalidate any orbit index cached by a previous fit.
        self._orbit_index = None
        self._fitted = True
        return self

    @property
    def node_colors(self) -> Dict[Hashable, int]:
        """
        Node-to-color mapping after refinement.

        :returns: Mapping node -> WL color id (a fresh copy).
        :rtype: dict[hashable, int]
        :raises RuntimeError: If not fitted.
        """
        self._ensure_fitted()
        return dict(self._colors)

    @property
    def orbits(self) -> List[FrozenSet[Hashable]]:
        """
        WL-equivalence classes (approximate automorphism orbits).

        :returns: List of frozensets, each representing an orbit, ordered by
            (size, smallest node).
        :rtype: list[frozenset[hashable]]
        :raises RuntimeError: If not fitted.
        """
        self._ensure_fitted()
        return list(self._orbits)

    @property
    def groups(self) -> List[List[Hashable]]:
        """
        Orbits represented as sorted lists.

        :returns: List of sorted node lists, ordered by (size, first node).
        :rtype: list[list[hashable]]
        :raises RuntimeError: If not fitted.
        """
        self._ensure_fitted()
        out: List[List[Hashable]] = [sorted(orb) for orb in self._orbits]
        out.sort(key=lambda g: (len(g), g[0] if g else 0))
        return out

    @property
    def orbit_index(self) -> Dict[Hashable, int]:
        """
        Map each node to its orbit id.

        The index is built on first access after a fit and cached; a copy is
        returned so callers cannot mutate the cache.

        :returns: Mapping node -> orbit_id.
        :rtype: dict[hashable, int]
        :raises RuntimeError: If not fitted.
        """
        self._ensure_fitted()
        if self._orbit_index is None:
            self._orbit_index = self._build_orbit_index()
        return dict(self._orbit_index)

    @property
    def n_orbits(self) -> int:
        """
        Number of approximate orbits.

        :returns: Number of orbits.
        :rtype: int
        :raises RuntimeError: If not fitted.
        """
        self._ensure_fitted()
        return len(self._orbits)

    @property
    def n_groups(self) -> int:
        """
        Alias for :attr:`n_orbits`.

        :returns: Number of orbits.
        :rtype: int
        """
        return self.n_orbits

    def __len__(self) -> int:
        """
        Number of orbits (0 if not fitted).

        :returns: Orbit count or 0.
        :rtype: int
        """
        return self.n_orbits if self._fitted else 0

    def __repr__(self) -> str:
        """
        Summary representation.

        :returns: Debug-friendly repr string.
        :rtype: str
        """
        n_nodes = self._graph.number_of_nodes()
        n_orb: object = len(self) if self._fitted else "?"
        return (
            f"<AutoEst | nodes={n_nodes} "
            f"orbits={n_orb} approx='WL-1' max_iter={self._cfg.max_iter}>"
        )

    # ------------------------------------------------------------------ #
    # "Components style" (similar spirit to Automorphism)
    # ------------------------------------------------------------------ #
    def components(
        self, nodes: Optional[Iterable[Hashable]] = None
    ) -> List[FrozenSet[Hashable]]:
        """
        Compute connected components on an induced subgraph.

        :param nodes: Subset of nodes to induce. If None, uses all nodes.
        :type nodes: iterable[hashable] or None
        :returns: Connected components as frozensets (deterministic order).
        :rtype: list[frozenset[hashable]]
        :raises RuntimeError: If not fitted.
        :raises ValueError: If ``nodes`` contains nodes not in the graph.
        """
        self._ensure_fitted()
        return self._components_on_induced(nodes)

    def orbit_components(
        self, nodes: Optional[Iterable[Hashable]] = None
    ) -> List[FrozenSet[int]]:
        """
        Components of the orbit-quotient graph restricted to an induced
        subgraph.

        - First restrict to ``nodes`` (or all nodes).
        - Collapse nodes to their orbit ids.
        - Build orbit-quotient adjacency based on edges between orbits.
        - Return connected components of orbit ids.

        This is useful when you have an "anchor component" defined as a set
        of pattern nodes and want to treat coupled orbits as a single unit.

        :param nodes: Subset of nodes. If None, uses all nodes.
        :type nodes: iterable[hashable] or None
        :returns: List of connected components in orbit-id space, ordered by
            (size, smallest orbit id).
        :rtype: list[frozenset[int]]
        :raises RuntimeError: If not fitted.
        :raises ValueError: If ``nodes`` contains nodes not in the graph.
        """
        self._ensure_fitted()
        keep = self._normalize_nodes(nodes)
        orbit_idx = self.orbit_index
        q = self._orbit_quotient_graph(keep, orbit_idx)
        comps = [frozenset(c) for c in nx.connected_components(q)]
        return sorted(comps, key=lambda c: (len(c), min(c)))

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #
    def _ensure_fitted(self) -> None:
        """
        Ensure estimator is fitted.

        :raises RuntimeError: If not fitted.
        """
        if not self._fitted:
            raise RuntimeError("Call 'fit()' before accessing results.")

    def _initialize_colors(self) -> None:
        """
        Initialize node colors using degree and selected node attributes.

        Color ids are assigned in order of first appearance while iterating
        the graph's nodes.
        """
        palette: Dict[Tuple[Any, ...], int] = {}
        colors: Dict[Hashable, int] = {}
        next_color = 0
        for node in self._graph.nodes():
            label = self._initial_label(node)
            if label not in palette:
                palette[label] = next_color
                next_color += 1
            colors[node] = palette[label]
        self._colors = colors

    def _initial_label(self, node: Hashable) -> Tuple[Any, ...]:
        """
        Build initial label.

        :param node: Node id.
        :type node: hashable
        :returns: Tuple label (degree + attr values; missing attrs -> None).
        :rtype: tuple
        """
        degree = self._graph.degree(node)
        attrs = self._graph.nodes[node]
        vals = [attrs.get(k) for k in self._cfg.node_attrs]
        return (degree, *vals)

    def _refine_colors(self) -> None:
        """
        Iterate WL refinement until convergence or max_iter.
        """
        for _ in range(self._cfg.max_iter):
            new_colors, changed = self._refine_once()
            self._colors = new_colors
            if not changed:
                break

    def _refine_once(self) -> Tuple[Dict[Hashable, int], bool]:
        """
        Single WL sweep.

        New color ids are again assigned in first-appearance order, so an
        unchanged partition reproduces the previous ids exactly and any
        split makes at least one node's id differ.

        :returns: (new_colors, changed)
        :rtype: tuple[dict[hashable, int], bool]
        """
        palette: Dict[Tuple[Any, ...], int] = {}
        new_colors: Dict[Hashable, int] = {}
        next_color = 0
        changed = False
        for node in self._graph.nodes():
            label = self._refined_label(node)
            if label not in palette:
                palette[label] = next_color
                next_color += 1
            c = palette[label]
            new_colors[node] = c
            if c != self._colors.get(node):
                changed = True
        return new_colors, changed

    @staticmethod
    def _sorted_signatures(
        sigs: List[Tuple[Any, ...]]
    ) -> List[Tuple[Any, ...]]:
        """
        Return neighbor signatures in a canonical order.

        Natural tuple ordering is used whenever it is defined. Signatures
        can contain ``None`` (missing edge attribute) next to real values,
        or values of mixed types, which makes ``<`` raise ``TypeError``; in
        that case ordering falls back to the signatures' ``repr`` so that
        refinement does not crash.

        :param sigs: Neighbor signatures of one node.
        :type sigs: list[tuple]
        :returns: New list with the same elements in canonical order.
        :rtype: list[tuple]
        """
        try:
            return sorted(sigs)
        except TypeError:
            # Only reached for inputs that previously raised; results for
            # naturally orderable signatures are unchanged.
            return sorted(sigs, key=repr)

    def _refined_label(self, node: Hashable) -> Tuple[Any, ...]:
        """
        Combine current color with sorted neighbor signatures.

        :param node: Node id.
        :type node: hashable
        :returns: Refined label ``(current_color, sorted_signatures)``.
        :rtype: tuple
        """
        base = self._colors[node]
        sigs: List[Tuple[Any, ...]] = [
            self._neighbor_signature(node, nbr)
            for nbr in self._graph.neighbors(node)
        ]
        return (base, tuple(self._sorted_signatures(sigs)))

    def _neighbor_signature(
        self, node: Hashable, neighbor: Hashable
    ) -> Tuple[Any, ...]:
        """
        Neighbor signature: (neighbor_color, edge_attr_1, ...).

        :param node: Central node.
        :type node: hashable
        :param neighbor: Neighbor node.
        :type neighbor: hashable
        :returns: Neighbor signature (missing edge attrs -> None).
        :rtype: tuple
        """
        edge_data = self._graph.get_edge_data(node, neighbor, default={})
        edge_vals = [edge_data.get(k) for k in self._cfg.edge_attrs]
        return (self._colors[neighbor], *edge_vals)

    def _build_orbits(self) -> None:
        """
        Group nodes by final colors.

        Orbits are stored ordered by (size, smallest node).
        """
        color_to_nodes: Dict[int, List[Hashable]] = {}
        for node, color in self._colors.items():
            color_to_nodes.setdefault(color, []).append(node)
        orbits = [frozenset(v) for v in color_to_nodes.values()]
        self._orbits = sorted(orbits, key=lambda o: (len(o), min(o)))

    def _build_orbit_index(self) -> Dict[Hashable, int]:
        """
        Build node -> orbit id mapping.

        The orbit id is the orbit's position in ``self._orbits``.

        :returns: Orbit index.
        :rtype: dict[hashable, int]
        """
        return {n: i for i, orb in enumerate(self._orbits) for n in orb}

    def _normalize_nodes(
        self, nodes: Optional[Iterable[Hashable]]
    ) -> FrozenSet[Hashable]:
        """
        Normalize subset nodes.

        :param nodes: Subset or None (meaning all graph nodes).
        :type nodes: iterable[hashable] or None
        :returns: Frozenset of nodes (validated).
        :rtype: frozenset[hashable]
        :raises ValueError: If any node is not in the graph.
        """
        if nodes is None:
            return frozenset(self._graph.nodes())
        keep = frozenset(nodes)
        unknown = [n for n in keep if n not in self._graph]
        if unknown:
            raise ValueError(f"Unknown nodes in subset: {unknown}")
        return keep

    def _components_on_induced(
        self, nodes: Optional[Iterable[Hashable]]
    ) -> List[FrozenSet[Hashable]]:
        """
        Connected components on induced subgraph.

        :param nodes: Subset or None.
        :type nodes: iterable[hashable] or None
        :returns: Components ordered by (size, smallest node).
        :rtype: list[frozenset[hashable]]
        :raises ValueError: If any node is not in the graph.
        """
        keep = self._normalize_nodes(nodes)
        sub = self._graph.subgraph(keep)
        comps = [frozenset(c) for c in nx.connected_components(sub)]
        return sorted(comps, key=lambda c: (len(c), min(c)))

    def _orbit_quotient_graph(
        self,
        keep_nodes: FrozenSet[Hashable],
        orbit_idx: Dict[Hashable, int],
    ) -> nx.Graph:
        """
        Build orbit quotient graph restricted to keep_nodes.

        :param keep_nodes: Nodes to keep.
        :type keep_nodes: frozenset[hashable]
        :param orbit_idx: Node->orbit id.
        :type orbit_idx: dict[hashable, int]
        :returns: Quotient graph in orbit-id space.
        :rtype: nx.Graph
        """
        q = nx.Graph()
        for n in keep_nodes:
            q.add_node(orbit_idx[n])
        for u, v in self._graph.edges():
            if u not in keep_nodes or v not in keep_nodes:
                continue
            ou = orbit_idx[u]
            ov = orbit_idx[v]
            # Edges inside a single orbit do not couple distinct orbits.
            if ou != ov:
                q.add_edge(ou, ov)
        return q

    # NOTE(review): the module/class documentation refers to public
    # ``deduplicate_host_orbits`` / ``deduplicate_pattern_orbits`` entry
    # points, which are not defined in the class as visible here. The
    # helpers below look like their building blocks - confirm whether the
    # public wrappers live elsewhere or still need to be added.
    def _validate_host_nodes_in_mappings(
        self,
        mappings: List[Mapping[Hashable, Hashable]],
        orbit_idx: Dict[Hashable, int],
    ) -> None:
        """
        Validate all host nodes are known.

        :param mappings: Mappings list; the *values* are checked.
        :type mappings: list[Mapping[hashable, hashable]]
        :param orbit_idx: Node->orbit id.
        :type orbit_idx: dict[hashable, int]
        :raises ValueError: If unknown nodes appear.
        """
        missing: List[Hashable] = [
            h for mp in mappings for h in mp.values() if h not in orbit_idx
        ]
        if missing:
            raise ValueError(
                "Host nodes in mappings not present in fitted graph: "
                f"{sorted(set(missing))}"
            )

    def _dedup_by_signature(
        self,
        mappings: List[Mapping[Hashable, Hashable]],
        sig_fn: Any,
    ) -> List[Mapping[Hashable, Hashable]]:
        """
        Deduplicate by computed signature (stable representative selection).

        The result is ordered by signature; for each distinct signature the
        mapping that appears first in ``mappings`` is kept.

        :param mappings: List of mappings.
        :type mappings: list[Mapping[hashable, hashable]]
        :param sig_fn: Function mapping mapping->signature. Signatures must
            be mutually orderable.
        :type sig_fn: callable
        :returns: Deduplicated list.
        :rtype: list[Mapping[hashable, hashable]]
        """
        # Compute each signature once, then reuse it for sorting and grouping.
        keyed = [(sig_fn(m), m) for m in mappings]
        # Sort on the signature only; stability keeps input order on ties.
        keyed.sort(key=lambda pair: pair[0])
        return [
            next(grp)[1] for _, grp in groupby(keyed, key=lambda pair: pair[0])
        ]

    def _sort_orbits(
        self, pattern_orbits: Iterable[FrozenSet[Hashable]]
    ) -> Tuple[Tuple[Hashable, ...], ...]:
        """
        Deterministically sort orbits.

        :param pattern_orbits: Orbits.
        :type pattern_orbits: iterable[frozenset[hashable]]
        :returns: Tuple of sorted orbit tuples, ordered by (size, first node).
        :rtype: tuple[tuple[hashable, ...], ...]
        """
        orbs = [tuple(sorted(o)) for o in pattern_orbits]
        orbs.sort(key=lambda o: (len(o), o[0] if o else 0))
        return tuple(orbs)

    def _dedup_pattern_orbits_no_anchor(
        self,
        mappings: List[Mapping[Hashable, Hashable]],
        orbits: Tuple[Tuple[Hashable, ...], ...],
    ) -> List[Mapping[Hashable, Hashable]]:
        """
        Pattern-orbit dedup without anchor.

        Two mappings are duplicates when, for every orbit, they send the
        orbit's pattern nodes to the same multiset of targets.

        :param mappings: Mappings.
        :type mappings: list[Mapping[hashable, hashable]]
        :param orbits: Sorted orbit tuples.
        :type orbits: tuple[tuple[hashable, ...], ...]
        :returns: Deduplicated mappings.
        :rtype: list[Mapping[hashable, hashable]]
        """

        def _sig(m: Mapping[Hashable, Hashable]) -> Tuple[Tuple[Hashable, ...], ...]:
            return tuple(tuple(sorted(m[p] for p in orb)) for orb in orbits)

        return self._dedup_by_signature(mappings, _sig)

    def _dedup_pattern_orbits_with_anchor(
        self,
        mappings: List[Mapping[Hashable, Hashable]],
        orbits: Tuple[Tuple[Hashable, ...], ...],
        anchor: FrozenSet[Hashable],
    ) -> List[Mapping[Hashable, Hashable]]:
        """
        Pattern-orbit dedup with anchor (no pruning inside anchor nodes).

        Orbits that contain any anchor node are excluded from the
        order-insensitive part of the signature; anchor nodes contribute
        their exact targets instead.

        :param mappings: Mappings.
        :type mappings: list[Mapping[hashable, hashable]]
        :param orbits: Sorted orbit tuples.
        :type orbits: tuple[tuple[hashable, ...], ...]
        :param anchor: Anchor nodes.
        :type anchor: frozenset[hashable]
        :returns: Deduplicated mappings.
        :rtype: list[Mapping[hashable, hashable]]
        """
        anchor_nodes = tuple(sorted(anchor))
        free_orbits = tuple(orb for orb in orbits if not (set(orb) & anchor))

        def _sig(m: Mapping[Hashable, Hashable]) -> Tuple[Any, ...]:
            free_part = tuple(tuple(sorted(m[p] for p in orb)) for orb in free_orbits)
            anchor_part = tuple(m[p] for p in anchor_nodes)
            return (free_part, anchor_part)

        return self._dedup_by_signature(mappings, _sig)
# --------------------------------------------------------------------------- #
# Convenience function (kept minimal)
# --------------------------------------------------------------------------- #
def estimate_automorphism_groups(
    graph: nx.Graph,
    node_attrs: Optional[Iterable[str]] = None,
    edge_attrs: Optional[Iterable[str]] = None,
    max_iter: int = 10,
) -> AutoEst:
    """
    Build an :class:`AutoEst` for ``graph`` and fit it in one call.

    :param graph: Input NetworkX graph.
    :type graph: nx.Graph
    :param node_attrs: Node attribute keys used for the initial WL labels.
        ``None`` is forwarded unchanged so the estimator's defaults apply.
    :type node_attrs: iterable[str] or None
    :param edge_attrs: Edge attribute keys used in neighbor signatures.
        ``None`` is forwarded unchanged so the estimator's defaults apply.
    :type edge_attrs: iterable[str] or None
    :param max_iter: Maximum WL iterations.
    :type max_iter: int
    :returns: Fitted estimator.
    :rtype: AutoEst
    """

    def _as_list(keys: Optional[Iterable[str]]) -> Optional[List[str]]:
        # Materialize any iterable into a list, but let None pass through.
        if keys is None:
            return None
        return list(keys)

    estimator = AutoEst(
        graph=graph,
        node_attrs=_as_list(node_attrs),
        edge_attrs=_as_list(edge_attrs),
        max_iter=max_iter,
    )
    return estimator.fit()