Source code for synkit.Graph.Matcher.partial_matcher

from __future__ import annotations

from itertools import combinations
from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union

import networkx as nx

from synkit.Graph.Matcher.subgraph_matcher import SubgraphSearchEngine
from synkit.Synthesis.Reactor.strategy import Strategy
from synkit.Graph.Matcher.auto_est import AutoEst  # WL-1 orbit estimator
from synkit.Graph.Matcher.dedup_matches import deduplicate_matches_with_anchor

MappingDict = Dict[int, int]

__all__ = ["PartialMatcher"]


[docs] class PartialMatcher: """ Component-subset helper for pattern→host subgraph matching. This matcher treats each connected component of the pattern as an independent "micro-pattern" and searches for consistent embeddings of subsets of these components into one or more host graphs. It can behave like a classic "partial matcher" (searching all component counts) or like a strict full-pattern matcher, depending on the ``partial`` flag. Internally, all embeddings for each pair (host, pattern component) are pre-computed once and then re-used for all combinations. This significantly reduces redundant calls to :class:`SubgraphSearchEngine` when exploring many subsets. Optionally, approximate WL-1 automorphism orbits can be used to prune embeddings that are equivalent under host symmetries via :paramref:`prune_auto`. Parameters ---------- host : nx.Graph | Sequence[nx.Graph] Single host graph or sequence of host graphs. pattern : nx.Graph Pattern graph whose connected components act as building blocks. node_attrs : list[str] Node attribute keys enforced equal during matching. edge_attrs : list[str] Edge attribute keys enforced equal during matching. strategy : Strategy, optional Matching strategy forwarded to :class:`SubgraphSearchEngine`. max_results : int | None, optional Global cap on number of embeddings to store. If ``None``, no explicit cap is applied. partial : bool, optional If ``True``, auto-mode (``k=None``) searches all component counts from full pattern down to 1. If ``False``, auto-mode only tries ``k = n_components`` (i.e. full-pattern matching only). threshold : int | None, optional Optional cap on embeddings *per (host, component)* pairing. If exceeded, that pairing is treated as "no valid embeddings" and skipped. Defaults to :data:`SubgraphSearchEngine.DEFAULT_THRESHOLD`. pre_filter : bool, optional If ``True``, enable the cheap Cartesian-product pre-filter in :class:`SubgraphSearchEngine` for each (host, component) pair. prune_auto : bool, optional If ``True``, apply approximate automorphism-based pruning on the final list of embeddings using WL-1 orbits computed by :class:`AutoEst`. For safety, pruning is only applied when there is a single host graph. Defaults to ``False``. wl_max_iter : int, optional Maximum number of WL refinement iterations in :class:`AutoEst` when :paramref:`prune_auto` is enabled. Defaults to 10. """ def __init__( self, host: Union[nx.Graph, Sequence[nx.Graph]], pattern: nx.Graph, node_attrs: List[str], edge_attrs: List[str], *, strategy: Strategy = Strategy.COMPONENT, max_results: Optional[int] = None, partial: bool = True, threshold: Optional[int] = None, pre_filter: bool = False, prune_auto: bool = False, wl_max_iter: int = 10, ) -> None: if isinstance(host, nx.Graph): self.hosts: List[nx.Graph] = [host] elif isinstance(host, Sequence): self.hosts = list(host) else: raise TypeError( "host must be a networkx.Graph or a sequence of such graphs" ) self.pattern: nx.Graph = pattern self.node_attrs: List[str] = node_attrs self.edge_attrs: List[str] = edge_attrs self.strategy: Strategy = strategy self.max_results: Optional[int] = max_results self.partial: bool = partial self._threshold: int = ( threshold if threshold is not None else SubgraphSearchEngine.DEFAULT_THRESHOLD ) self._pre_filter: bool = pre_filter # WL-1 approximate automorphism pruning settings self._prune_auto: bool = bool(prune_auto) self._wl_max_iter: int = int(wl_max_iter) # WL-style approximate embedding count self._approx_embedding_count: Optional[int] = None self._pattern_ccs: List[nx.Graph] = self._split_pattern_components() self._host_embeddings: List[List[List[MappingDict]]] = [] self._precompute_embeddings() mappings = self._match_components(k=None) if self._prune_auto: mappings = self._prune_automorphic_mappings(mappings) self._mappings: List[MappingDict] = mappings # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ def _split_pattern_components(self) -> List[nx.Graph]: """ Split the pattern into connected components. :returns: List of connected component subgraphs. :rtype: list[nx.Graph] :raises ValueError: If the pattern has no components. """ components = [ self.pattern.subgraph(c).copy() for c in nx.connected_components(self.pattern) ] if not components: raise ValueError("Pattern graph has no components.") return components def _precompute_embeddings(self) -> None: """ Pre-compute embeddings for each (host, component) pair. The results are stored in :attr:`_host_embeddings` as a nested list indexed as ``[host_index][component_index]``. """ host_embeddings: List[List[List[MappingDict]]] = [] for host in self.hosts: comp_embeddings: List[List[MappingDict]] = [] for pat_cc in self._pattern_ccs: embeddings = SubgraphSearchEngine.find_subgraph_mappings( host, pat_cc, node_attrs=self.node_attrs, edge_attrs=self.edge_attrs, strategy=self.strategy, max_results=self.max_results, strict_cc_count=False, threshold=self._threshold, pre_filter=self._pre_filter, ) comp_embeddings.append(embeddings) host_embeddings.append(comp_embeddings) self._host_embeddings = host_embeddings def _prune_automorphic_mappings( self, mappings: List[MappingDict], ) -> List[MappingDict]: """ Approximate automorphism-based pruning using WL-1 orbits. This uses :class:`AutoEst` on the (single) host graph to collapse mappings that hit the same multiset of WL-orbit indices. For multiple hosts, the input list is returned unchanged to avoid node-id collisions across different hosts. :param mappings: Flat list of pattern→host mappings. :type mappings: list[MappingDict] :returns: Possibly pruned list of mappings. :rtype: list[MappingDict] """ if not mappings: return [] if len(self.hosts) != 1: return mappings host = self.hosts[0] est = AutoEst( graph=host, node_attrs=self.node_attrs, edge_attrs=self.edge_attrs, max_iter=self._wl_max_iter, ).fit() # AutoEst.deduplicate expects Mapping[hashable, hashable] maps = deduplicate_matches_with_anchor( matches=mappings, host_orbits=est.orbits, host_anchor=est.anchor_component ) return maps @staticmethod def _build_label_hist( graph: nx.Graph, node_attrs: Sequence[str], ) -> Dict[Tuple[Any, ...], int]: """ Build WL-style initial label histogram for a graph. The label is ``(degree, attrs[node_attr_0], attrs[node_attr_1], ...)``. :param graph: Input NetworkX graph. :type graph: nx.Graph :param node_attrs: Node attribute keys to include in the label. :type node_attrs: Sequence[str] :returns: Mapping from label tuple to count. :rtype: dict[tuple, int] """ hist: Dict[Tuple[Any, ...], int] = {} for node in graph.nodes(): degree = graph.degree(node) attrs = graph.nodes[node] values = [attrs.get(key) for key in node_attrs] label = (degree, *values) hist[label] = hist.get(label, 0) + 1 return hist @staticmethod def _approx_pair_count( host_hist: Dict[Tuple[Any, ...], int], comp_hist: Dict[Tuple[Any, ...], int], ) -> int: """ Approximate number of label-consistent injective mappings for a (host, component) pair, ignoring adjacency. :param host_hist: Label histogram for the host graph. :type host_hist: dict :param comp_hist: Label histogram for the pattern component. :type comp_hist: dict :returns: Upper-bound estimate of injective mappings. :rtype: int """ total = 1 for label, p_count in comp_hist.items(): h_count = host_hist.get(label, 0) if h_count < p_count: return 0 for i in range(p_count): total *= h_count - i return total # ------------------------------------------------------------------ # WL-style approximate helpers (factorised to reduce complexity) # ------------------------------------------------------------------ def _normalise_k_values(self, k: Optional[int]) -> List[int]: """ Normalise ``k`` to a list of component counts to consider. :param k: Desired number of components or ``None`` for auto-mode. :type k: int | None :returns: List of component counts to iterate over. :rtype: list[int] :raises ValueError: If ``k`` is outside ``[1, n_components]``. """ n_cc = len(self._pattern_ccs) if k is not None: if k <= 0 or k > n_cc: raise ValueError(f"k must be between 1 and {n_cc}") return [k] if not self.partial: return [n_cc] return list(range(n_cc, 0, -1)) def _build_host_hists(self) -> List[Dict[Tuple[Any, ...], int]]: """ Build WL-style label histograms for all host graphs. :returns: List of label histograms, one per host. :rtype: list[dict[tuple, int]] """ host_hists: List[Dict[Tuple[Any, ...], int]] = [] for host in self.hosts: hist = self._build_label_hist(host, self.node_attrs) host_hists.append(hist) return host_hists def _build_comp_hists(self) -> List[Dict[Tuple[Any, ...], int]]: """ Build WL-style label histograms for all pattern components. :returns: List of label histograms, one per component. :rtype: list[dict[tuple, int]] """ comp_hists: List[Dict[Tuple[Any, ...], int]] = [] for pat_cc in self._pattern_ccs: hist = self._build_label_hist(pat_cc, self.node_attrs) comp_hists.append(hist) return comp_hists def _compute_pair_counts( self, host_hists: List[Dict[Tuple[Any, ...], int]], comp_hists: List[Dict[Tuple[Any, ...], int]], ) -> List[List[int]]: """ Pre-compute approximate counts for all (host, component) pairs. :param host_hists: WL label histograms for hosts. :type host_hists: list[dict[tuple, int]] :param comp_hists: WL label histograms for components. :type comp_hists: list[dict[tuple, int]] :returns: Matrix of approximate counts indexed as [host][component]. :rtype: list[list[int]] """ pair_counts: List[List[int]] = [] for host_hist in host_hists: row: List[int] = [] for comp_hist in comp_hists: count = self._approx_pair_count(host_hist, comp_hist) row.append(count) pair_counts.append(row) return pair_counts def _product_for_combo( self, row: List[int], combo: Sequence[int], ) -> int: """ Compute product of pair counts for a single host/combination. :param row: List of pair counts for one host over all components. :type row: list[int] :param combo: Selected component indices. :type combo: Sequence[int] :returns: Product of counts, or 0 if any factor is 0. :rtype: int """ prod_val = 1 for cc_idx in combo: pair_count = row[cc_idx] if pair_count == 0: return 0 prod_val *= pair_count return prod_val def _aggregate_pair_counts( self, pair_counts: List[List[int]], k_values: List[int], ) -> int: """ Aggregate per-pair estimates over component subsets and hosts. :param pair_counts: Matrix of approximate counts [host][component]. :type pair_counts: list[list[int]] :param k_values: Component counts to consider. :type k_values: list[int] :returns: Aggregated estimate (may be truncated by ``max_results``). :rtype: int """ total_est = 0 n_cc = len(self._pattern_ccs) cc_indices = range(n_cc) for k_try in k_values: for combo in combinations(cc_indices, k_try): for host_idx, row in enumerate(pair_counts): prod_val = self._product_for_combo(row, combo) if prod_val == 0: continue total_est += prod_val if self.max_results and total_est >= self.max_results: return int(self.max_results) return int(total_est) # ------------------------------------------------------------------ # Core matching logic # ------------------------------------------------------------------ def _match_components(self, k: Optional[int] = None) -> List[MappingDict]: """ Internal search – returns a *flat* list of embeddings. :param k: Number of connected components of the pattern to use. * If an integer, the search is restricted to subsets of exactly ``k`` pattern components. * If ``None``, behaviour depends on :attr:`partial`: - ``partial=False`` → only ``k = n_components`` is used. - ``partial=True`` → searches all feasible ``k`` from ``n_components`` down to 1. :type k: int | None :returns: Flat list of pattern→host node mappings. :rtype: list[MappingDict] """ if k is not None: return self._match_fixed_k(k) if not self.partial: return self._match_fixed_k(len(self._pattern_ccs)) return self._match_all_k() def _match_all_k(self) -> List[MappingDict]: """ Aggregate embeddings over all feasible component counts. This tries ``k = n_cc, n_cc-1, ..., 1`` and stops once :attr:`max_results` is reached (if set). :returns: Flat list of pattern→host node mappings. :rtype: list[MappingDict] """ all_mappings: List[MappingDict] = [] n_cc = len(self._pattern_ccs) for k_try in range(n_cc, 0, -1): mappings = self._match_fixed_k(k_try) if not mappings: continue for emb in mappings: all_mappings.append(emb) if self.max_results and len(all_mappings) >= self.max_results: return all_mappings return all_mappings def _match_fixed_k(self, k: int) -> List[MappingDict]: """ Match using exactly ``k`` connected components of the pattern. :param k: Number of connected components to select. :type k: int :returns: Flat list of pattern→host node mappings. :rtype: list[MappingDict] :raises ValueError: If ``k`` is outside ``[1, n_components]``. """ n_cc = len(self._pattern_ccs) if k <= 0 or k > n_cc: raise ValueError(f"k must be between 1 and {n_cc}") all_mappings: List[MappingDict] = [] cc_indices = range(n_cc) for combo in combinations(cc_indices, k): for host_index, _host in enumerate(self.hosts): self._backtrack_components( combo=combo, host_index=host_index, level=0, used_nodes=set(), accum={}, out=all_mappings, ) if self.max_results and len(all_mappings) >= self.max_results: return all_mappings return all_mappings def _backtrack_components( self, combo: Sequence[int], host_index: int, level: int, used_nodes: Set[int], accum: MappingDict, out: List[MappingDict], ) -> None: """ Backtracking across selected components within a single host. :param combo: Sequence of component indices to match in order. :type combo: Sequence[int] :param host_index: Index of the current host in :attr:`hosts`. :type host_index: int :param level: Current recursion depth (index in ``combo``). :type level: int :param used_nodes: Set of host node ids already used. :type used_nodes: set[int] :param accum: Accumulated pattern→host mapping. :type accum: MappingDict :param out: List where completed mappings are appended. :type out: list[MappingDict] """ if self.max_results and len(out) >= self.max_results: return if level == len(combo): out.append(accum.copy()) return cc_idx = combo[level] embeddings = self._host_embeddings[host_index][cc_idx] if not embeddings: return for emb in embeddings: mapped = set(emb.values()) if mapped & used_nodes: continue new_used = used_nodes | mapped new_accum = {**accum, **emb} self._backtrack_components( combo=combo, host_index=host_index, level=level + 1, used_nodes=new_used, accum=new_accum, out=out, ) # ------------------------------------------------------------------ # WL-style approximate embedding count # ------------------------------------------------------------------
[docs] def estimate_embeddings_wl(self, k: Optional[int] = None) -> "PartialMatcher": """ Estimate the number of embeddings using WL-style initial labels. This is a **cheap, approximate upper bound** that: * Builds WL-style labels ``(degree, node_attrs...)`` on the host and pattern components. * For each (host, component) pair, estimates the number of label-consistent injective mappings ignoring adjacency, via a product of falling factorials per label class. * Aggregates these per-pair estimates over subsets of pattern components using the same semantics as :meth:`_match_components`. No calls to :class:`SubgraphSearchEngine` or backtracking are performed. The result is stored in :attr:`approx_embedding_count`. :param k: Number of pattern components to use. If ``None``, behaviour mirrors :meth:`_match_components`: * ``partial=False`` → use only full pattern (``k=n_cc``). * ``partial=True`` → aggregate over all k from ``n_cc`` down to 1. :type k: int | None :returns: The estimator itself (for chained use). :rtype: PartialMatcher """ k_values = self._normalise_k_values(k) host_hists = self._build_host_hists() comp_hists = self._build_comp_hists() pair_counts = self._compute_pair_counts(host_hists, comp_hists) total_est = self._aggregate_pair_counts(pair_counts, k_values) if self.max_results and total_est >= self.max_results: self._approx_embedding_count = int(self.max_results) else: self._approx_embedding_count = int(total_est) return self
@property def approx_embedding_count(self) -> int: """ WL-style approximate embedding count. :returns: Last estimated embedding count. :rtype: int :raises RuntimeError: If :meth:`estimate_embeddings_wl` has not been called. """ if self._approx_embedding_count is None: raise RuntimeError( "Call 'estimate_embeddings_wl()' before accessing " "'approx_embedding_count'." ) return self._approx_embedding_count # ------------------------------------------------------------------ # Public instance helpers # ------------------------------------------------------------------
[docs] def get_mappings(self) -> List[MappingDict]: """ Return the list of discovered embeddings (auto-computed). :returns: List of pattern→host node mappings. :rtype: list[MappingDict] """ return self._mappings
@property def num_mappings(self) -> int: """ Number of embeddings found. :returns: Count of discovered embeddings. :rtype: int """ return len(self._mappings) @property def num_pattern_components(self) -> int: """ Number of connected components in the pattern graph. :returns: Number of pattern connected components. :rtype: int """ return len(self._pattern_ccs) @property def threshold(self) -> int: """ Per-(host, component) embedding threshold. :returns: Threshold passed to :class:`SubgraphSearchEngine`. :rtype: int """ return self._threshold @property def pre_filter(self) -> bool: """ Whether the cheap pre-filter is enabled. :returns: Current value of the pre-filter flag. :rtype: bool """ return self._pre_filter # Iteration support ------------------------------------------------- def __iter__(self) -> Iterator[MappingDict]: """ Iterate over discovered embeddings. :returns: Iterator over mapping dictionaries. :rtype: Iterator[MappingDict] """ return iter(self._mappings) # Niceties ---------------------------------------------------------- def __repr__(self) -> str: """ Representation string for debugging. :returns: Short summary of matcher state. :rtype: str """ return ( f"<PartialMatcher pattern_ccs={self.num_pattern_components} " f"hosts={len(self.hosts)} mappings={self.num_mappings} " f"partial={self.partial} threshold={self._threshold} " f"pre_filter={self._pre_filter} prune_auto={self._prune_auto}>" ) __str__ = __repr__ @property def help(self) -> str: """ Return the full module docstring. :returns: Module-level documentation string. :rtype: str """ return __doc__ or "" # ------------------------------------------------------------------ # Functional/staticmethod wrapper # ------------------------------------------------------------------
[docs] @staticmethod def find_partial_mappings( host: Union[nx.Graph, Sequence[nx.Graph]], pattern: nx.Graph, *, node_attrs: List[str], edge_attrs: List[str], k: Optional[int] = None, strategy: Strategy = Strategy.COMPONENT, max_results: Optional[int] = None, partial: bool = True, threshold: Optional[int] = None, pre_filter: bool = False, prune_auto: bool = False, wl_max_iter: int = 10, ) -> List[MappingDict]: """ Stateless convenience wrapper – one-liner for users in a hurry. This mirrors the OO API but avoids explicitly instantiating the matcher in user code. :param host: A single host graph or a sequence of host graphs. :type host: nx.Graph | Sequence[nx.Graph] :param pattern: Pattern graph whose connected components are used as building blocks. :type pattern: nx.Graph :param node_attrs: Node attribute keys to enforce equality on during matching. :type node_attrs: list[str] :param edge_attrs: Edge attribute keys to enforce equality on during matching. :type edge_attrs: list[str] :param k: If an integer, restricts the search to subsets of exactly ``k`` pattern connected components. If ``None``, behaviour follows the ``partial`` flag. :type k: int | None :param strategy: Matching strategy forwarded to :class:`SubgraphSearchEngine`. :type strategy: Strategy :param max_results: Optional global cap on the number of embeddings to return. :type max_results: int | None :param partial: If ``True``, all component counts are tried in auto-mode. If ``False``, only the full pattern is used. :type partial: bool :param threshold: Optional per-(host, component) embedding cap forwarded to :class:`SubgraphSearchEngine`. :type threshold: int | None :param pre_filter: Whether to enable the cheap pre-filter in :class:`SubgraphSearchEngine`. :type pre_filter: bool :param prune_auto: If ``True``, apply WL-1-based approximate automorphism pruning on the final mappings. :type prune_auto: bool :param wl_max_iter: Maximum number of WL iterations for the internal :class:`AutoEst` if :paramref:`prune_auto` is enabled. :type wl_max_iter: int :returns: Flat list of pattern→host node mappings. :rtype: list[MappingDict] """ matcher = PartialMatcher( host=host, pattern=pattern, node_attrs=node_attrs, edge_attrs=edge_attrs, strategy=strategy, max_results=max_results, partial=partial, threshold=threshold, pre_filter=pre_filter, prune_auto=prune_auto, wl_max_iter=wl_max_iter, ) if k is not None: return matcher._match_components(k) return matcher.get_mappings()