from __future__ import annotations
from dataclasses import dataclass
from time import perf_counter
from typing import Any, Dict, List, Optional, Set, Tuple
import networkx as nx
from ._common import (
SymmetryConfig,
approx_automorphism_count_from_cells,
build_fast_signature,
edge_token,
graph_key_from_order,
hash_text,
node_token,
prepare_graph,
)
@dataclass(frozen=True)
class _WLState:
"""
Internal cached WL state.
:param colors:
Final node colors after WL refinement.
:type colors: Dict[Any, str]
:param cells:
Final WL color cells as sorted node lists.
:type cells: List[List[Any]]
:param orbits:
Approximate node orbits induced by final WL colors.
:type orbits: List[Set[Any]]
:param color_hist:
Histogram of final colors.
:type color_hist: Dict[str, int]
:param iters_run:
Number of WL iterations actually performed.
:type iters_run: int
:param stabilized:
Whether refinement stabilized before the iteration limit.
:type stabilized: bool
:param canonical_order:
Deterministic node order induced by final WL colors.
:type canonical_order: List[Any]
:param approx_automorphism_count:
Approximate automorphism count derived from WL cells.
:type approx_automorphism_count: Optional[int]
"""
colors: Dict[Any, str]
cells: List[List[Any]]
orbits: List[Set[Any]]
color_hist: Dict[str, int]
iters_run: int
stabilized: bool
canonical_order: List[Any]
approx_automorphism_count: Optional[int]
@dataclass(frozen=True)
class WLCanonicalResult:
    """
    Outcome of an approximate, WL-based canonicalization.

    The layout follows the result style of the exact canonicalizer, but every
    symmetry-related value here is an approximation.

    :param canon_graph:
        The graph relabeled according to the WL node order.
    :type canon_graph: nx.DiGraph
    :param graph_type:
        Representation type of the graph, e.g. ``"bipartite"`` or
        ``"species"``.
    :type graph_type: str
    :param canonical_order:
        Deterministic node order induced by the final WL partition.
    :type canonical_order: List[Any]
    :param canonical_key:
        Graph key derived from the WL node order.
    :type canonical_key: Any
    :param automorphism_count:
        Approximate automorphism count obtained from the WL cells.
    :type automorphism_count: Optional[int]
    :param orbits:
        Approximate node orbits (the final WL cells as sets).
    :type orbits: List[Set[Any]]
    :param colors:
        Final node-to-color mapping.
    :type colors: Dict[Any, str]
    :param color_hist:
        Histogram of the final WL colors.
    :type color_hist: Dict[str, int]
    :param iters_run:
        Number of WL rounds that were actually executed.
    :type iters_run: int
    :param stabilized:
        Whether refinement stabilized before the maximum round count.
    :type stabilized: bool
    :param exact:
        Always ``False`` for WL refinement.
    :type exact: bool
    :param elapsed_seconds:
        Time in seconds spent building this result object.
    :type elapsed_seconds: float
    """

    # NOTE: field order is part of the positional constructor signature.
    canon_graph: nx.DiGraph
    graph_type: str
    canonical_order: List[Any]
    canonical_key: Any
    automorphism_count: Optional[int]
    orbits: List[Set[Any]]
    colors: Dict[Any, str]
    color_hist: Dict[str, int]
    iters_run: int
    stabilized: bool
    exact: bool
    elapsed_seconds: float
class WLCanonicalizer:
    """
    Fast approximate canonicalizer for SynKit CRN graphs using direction-aware
    1-WL refinement.

    This class is designed as a lightweight companion to the exact CRN
    canonicalizer. It gives:

    - deterministic WL-based canonical relabeling
    - approximate orbit partitions
    - approximate automorphism counts from WL cells
    - fast signatures for cheap prefiltering

    Compared with the exact canonicalizer, this class is much faster but not
    guaranteed to distinguish all non-isomorphic graphs or recover exact
    automorphism groups.

    Results are memoized per instance. The memo is keyed on the tuple built
    by :meth:`_cache_key`, so changing a public parameter (for example
    ``n_iter``) after construction triggers a fresh refinement on the next
    query.

    :param source:
        Input CRN representation. This may be a prepared
        :class:`networkx.DiGraph`, an object exposing ``to_digraph()``, or any
        object accepted by :func:`prepare_graph`.
    :type source: Any
    :param include_rule:
        Whether rule nodes should be included in the prepared graph.
    :type include_rule: bool
    :param include_stoich:
        Whether stoichiometric edge attributes should be preserved during
        graph preparation.
    :type include_stoich: bool
    :param n_iter:
        Maximum number of WL refinement iterations.
    :type n_iter: int
    :param digest_size:
        Digest size used when hashing node and edge signatures.
    :type digest_size: int
    :param include_in_neighbors:
        Whether incoming neighbors should contribute to refinement.
    :type include_in_neighbors: bool
    :param include_out_neighbors:
        Whether outgoing neighbors should contribute to refinement.
    :type include_out_neighbors: bool
    :param estimate_automorphisms:
        Whether to compute an approximate automorphism count from the final
        WL cells.
    :type estimate_automorphisms: bool
    :param automorphism_cap:
        Cap applied to approximate automorphism counts.
    :type automorphism_cap: int
    :param config:
        Symmetry semantics configuration controlling which node and edge
        attributes participate in WL coloring. Defaults to
        ``SymmetryConfig.semantic()``.
    :type config: Optional[SymmetryConfig]

    Example
    -------
    .. code-block:: python

        from synkit.CRN.Sym import WLCanonicalizer, SymmetryConfig

        wl = WLCanonicalizer(
            syn.to_digraph(),
            include_rule=True,
            config=SymmetryConfig.topological(),
        )
        print(wl.has_nontrivial_automorphism())
        print(wl.orbits())
        print(wl.canonical_order())
        print(wl.summary()["automorphism_count"])
    """

    def __init__(
        self,
        source: Any,
        *,
        include_rule: bool = True,
        include_stoich: bool = True,
        n_iter: int = 20,
        digest_size: int = 16,
        include_in_neighbors: bool = True,
        include_out_neighbors: bool = True,
        estimate_automorphisms: bool = True,
        automorphism_cap: int = 10**18,
        config: Optional[SymmetryConfig] = None,
    ) -> None:
        """
        Initialize the WL canonicalizer.

        Parameters are coerced to ``bool``/``int``, the graph is prepared once
        through :func:`prepare_graph`, and all memo slots start empty. See the
        class docstring for the meaning of each parameter.

        :returns:
            None.
        :rtype: None

        Example
        -------
        .. code-block:: python

            wl = WLCanonicalizer(
                syn.to_digraph(),
                include_rule=True,
                n_iter=20,
                digest_size=16,
            )
        """
        self.source = source
        self.include_rule = bool(include_rule)
        self.include_stoich = bool(include_stoich)
        self.n_iter = int(n_iter)
        self.digest_size = int(digest_size)
        self.include_in_neighbors = bool(include_in_neighbors)
        self.include_out_neighbors = bool(include_out_neighbors)
        self.estimate_automorphisms = bool(estimate_automorphisms)
        self.automorphism_cap = int(automorphism_cap)
        self.config = config or SymmetryConfig.semantic()

        self._G, self._graph_type = prepare_graph(
            source,
            include_rule=self.include_rule,
            include_stoich=self.include_stoich,
        )

        # Memo slots; all of them are tied to ``_cache_key_last``.
        self._state_cache: Optional[_WLState] = None
        self._summary_cache: Optional[WLCanonicalResult] = None
        self._fast_signature: Optional[Tuple[Any, ...]] = None
        self._cache_key_last: Optional[Tuple[Any, ...]] = None

    def __repr__(self) -> str:
        """
        Return a concise representation.

        :returns:
            String representation.
        :rtype: str
        """
        return (
            f"WLCanonicalizer(include_rule={self.include_rule}, "
            f"graph_type={self.graph_type}, n_iter={self.n_iter}, "
            f"digest_size={self.digest_size})"
        )

    @property
    def G(self) -> nx.DiGraph:
        """
        Return the prepared graph.

        :returns:
            Prepared directed graph.
        :rtype: nx.DiGraph

        Example
        -------
        .. code-block:: python

            wl = WLCanonicalizer(syn)
            print(wl.G.number_of_nodes(), wl.G.number_of_edges())
        """
        return self._G

    @property
    def graph_type(self) -> str:
        """
        Return the graph representation type.

        :returns:
            Graph representation type.
        :rtype: str
        """
        return self._graph_type

    def _cache_key(self) -> Tuple[Any, ...]:
        """
        Build a conservative cache key for the current graph and parameters.

        The key contains the identity and node/edge counts of the prepared
        graph plus every tunable parameter. An in-place graph edit that keeps
        both counts unchanged is therefore not detected.

        :returns:
            Cache key tuple.
        :rtype: Tuple[Any, ...]
        """
        return (
            id(self.G),
            self.G.number_of_nodes(),
            self.G.number_of_edges(),
            self.include_rule,
            self.include_stoich,
            self.n_iter,
            self.digest_size,
            self.include_in_neighbors,
            self.include_out_neighbors,
            self.estimate_automorphisms,
            self.automorphism_cap,
            self.config,
        )

    def _node_seed(self, v: Any) -> str:
        """
        Compute the initial WL color for one node.

        The seed combines the semantic node token and the directed degree.

        :param v:
            Node identifier.
        :type v: Any
        :returns:
            Initial hashed node color.
        :rtype: str
        """
        tok = node_token(self.G.nodes[v], self.config)
        deg = (self.G.in_degree(v), self.G.out_degree(v))
        return hash_text(f"N|{tok}|{deg}", digest_size=self.digest_size)

    def _edge_sig(self, attrs: Dict[str, Any]) -> str:
        """
        Compute a hashed signature for one edge attribute dictionary.

        :param attrs:
            Edge attributes.
        :type attrs: Dict[str, Any]
        :returns:
            Hashed edge signature.
        :rtype: str
        """
        return hash_text(
            f"E|{edge_token(attrs, self.config)}",
            digest_size=self.digest_size,
        )

    def _edge_sig_between(self, u: Any, v: Any) -> str:
        """
        Return a stable signature for the edge between two nodes.

        For multigraphs, the minimum signature over parallel edges is used to
        keep the behavior deterministic. When no usable attribute dictionary
        is found, the signature of an empty attribute dictionary is returned.

        :param u:
            Source node.
        :type u: Any
        :param v:
            Target node.
        :type v: Any
        :returns:
            Stable edge signature.
        :rtype: str
        """
        data = self.G.get_edge_data(u, v, default=None)
        if not isinstance(data, dict):
            # Missing edge or unexpected payload: fall back to "no attributes".
            return self._edge_sig({})
        if self.G.is_multigraph():
            # ``data`` maps edge keys to per-edge attribute dictionaries.
            sigs = [
                self._edge_sig(attrs)
                for attrs in data.values()
                if isinstance(attrs, dict)
            ]
            return min(sigs) if sigs else self._edge_sig({})
        return self._edge_sig(data)

    def _neighbors_items(
        self,
        colors: Dict[Any, str],
        v: Any,
        *,
        direction: str,
    ) -> List[str]:
        """
        Collect colored neighbor-edge descriptors for one node.

        Each descriptor is ``"<neighbor color>#<edge signature>"``. On an
        undirected graph, ``"in"`` and ``"out"`` both degrade to ``"undir"``.
        An unrecognized direction yields an empty list.

        :param colors:
            Current node colors.
        :type colors: Dict[Any, str]
        :param v:
            Node identifier.
        :type v: Any
        :param direction:
            Neighborhood direction, one of ``"in"``, ``"out"``, or
            ``"undir"``.
        :type direction: str
        :returns:
            Sorted color-edge descriptors.
        :rtype: List[str]
        """
        graph = self.G
        if direction in ("in", "out") and not graph.is_directed():
            direction = "undir"

        items: List[str] = []
        if direction == "in":
            items = [
                f"{colors[u]}#{self._edge_sig_between(u, v)}"
                for u in graph.predecessors(v)
            ]
        elif direction == "out":
            items = [
                f"{colors[u]}#{self._edge_sig_between(v, u)}"
                for u in graph.successors(v)
            ]
        elif direction == "undir":
            for u in graph.neighbors(v):
                es = self._edge_sig_between(v, u) if graph.has_edge(v, u) else ""
                if not es and graph.has_edge(u, v):
                    es = self._edge_sig_between(u, v)
                items.append(f"{colors[u]}#{es}")

        # Sorting makes the descriptor list independent of adjacency order.
        items.sort()
        return items

    def _refine_once(self, colors: Dict[Any, str]) -> Dict[Any, str]:
        """
        Perform one WL refinement round.

        The new color of a node hashes its previous color together with the
        sorted in- and/or out-neighborhood descriptors, depending on
        ``include_in_neighbors`` and ``include_out_neighbors``.

        :param colors:
            Current node colors.
        :type colors: Dict[Any, str]
        :returns:
            Refined node colors.
        :rtype: Dict[Any, str]
        """
        new_colors: Dict[Any, str] = {}
        for v in self.G.nodes():
            parts: List[str] = [colors[v]]
            if self.include_in_neighbors:
                incoming = self._neighbors_items(colors, v, direction="in")
                parts.append("IN[" + "|".join(incoming) + "]")
            if self.include_out_neighbors:
                outgoing = self._neighbors_items(colors, v, direction="out")
                parts.append("OUT[" + "|".join(outgoing) + "]")
            new_colors[v] = hash_text(
                "||".join(parts),
                digest_size=self.digest_size,
            )
        return new_colors

    @staticmethod
    def _colors_equal(a: Dict[Any, str], b: Dict[Any, str]) -> bool:
        """
        Compare two color mappings exactly.

        :param a:
            First color mapping.
        :type a: Dict[Any, str]
        :param b:
            Second color mapping.
        :type b: Dict[Any, str]
        :returns:
            ``True`` if both mappings are identical.
        :rtype: bool
        """
        if a.keys() != b.keys():
            return False
        return all(a[k] == b[k] for k in a)

    @staticmethod
    def _same_partition(a: Dict[Any, str], b: Dict[Any, str]) -> bool:
        """
        Check whether two colorings induce the same node partition.

        Color labels themselves are ignored: the colorings match when there
        is a one-to-one correspondence between the labels of ``a`` and the
        labels of ``b`` that is respected by every node.

        :param a:
            First color mapping.
        :type a: Dict[Any, str]
        :param b:
            Second color mapping.
        :type b: Dict[Any, str]
        :returns:
            ``True`` if both mappings group the nodes identically.
        :rtype: bool
        """
        if a.keys() != b.keys():
            return False
        a_to_b: Dict[str, str] = {}
        b_to_a: Dict[str, str] = {}
        for node, color_a in a.items():
            color_b = b[node]
            if a_to_b.setdefault(color_a, color_b) != color_b:
                return False  # one class of ``a`` was split in ``b``
            if b_to_a.setdefault(color_b, color_a) != color_a:
                return False  # two classes of ``a`` were merged in ``b``
        return True

    @staticmethod
    def _buckets_from_colors(colors: Dict[Any, str]) -> Dict[str, List[Any]]:
        """
        Group nodes by final color.

        :param colors:
            Node-to-color mapping.
        :type colors: Dict[Any, str]
        :returns:
            Color buckets.
        :rtype: Dict[str, List[Any]]
        """
        buckets: Dict[str, List[Any]] = {}
        for v, c in colors.items():
            buckets.setdefault(c, []).append(v)
        return buckets

    @staticmethod
    def _orbits_from_buckets(buckets: Dict[str, List[Any]]) -> List[Set[Any]]:
        """
        Build approximate orbit sets from color buckets.

        :param buckets:
            Color buckets.
        :type buckets: Dict[str, List[Any]]
        :returns:
            Approximate orbit sets, ordered by color.
        :rtype: List[Set[Any]]
        """
        # Bucket keys are unique, so ordering by color alone is total.
        return [set(buckets[color]) for color in sorted(buckets)]

    @staticmethod
    def _canonical_order_from_colors(
        G: nx.DiGraph,
        colors: Dict[Any, str],
    ) -> List[Any]:
        """
        Build a deterministic node order from final colors.

        Nodes are ordered by color first and by ``str(node)`` second.

        :param G:
            Input graph.
        :type G: nx.DiGraph
        :param colors:
            Final color mapping.
        :type colors: Dict[Any, str]
        :returns:
            Deterministic canonical order.
        :rtype: List[Any]
        """
        return sorted(G.nodes(), key=lambda v: (colors[v], str(v)))

    def _run(self) -> _WLState:
        """
        Run WL refinement once and cache the result.

        Refinement stops as soon as a round no longer changes the node
        partition, or after ``n_iter`` rounds. The colors of the last executed
        round are kept. Whenever a new state is computed, the memoized summary
        and fast signature are dropped.

        :returns:
            Internal cached WL state.
        :rtype: _WLState

        Example
        -------
        .. code-block:: python

            wl = WLCanonicalizer(syn)
            state = wl._run()
            print(state.iters_run, state.stabilized)
        """
        key = self._cache_key()
        if self._state_cache is not None and self._cache_key_last == key:
            return self._state_cache

        colors: Dict[Any, str] = {v: self._node_seed(v) for v in self.G.nodes()}
        stabilized = False
        iters_run = 0
        for iters_run in range(1, self.n_iter + 1):
            refined = self._refine_once(colors)
            # Each refined color is a hash of text that embeds the previous
            # color, so comparing color strings across rounds would
            # practically never match. Stability is decided on the partition.
            stabilized = self._same_partition(colors, refined)
            colors = refined
            if stabilized:
                break

        buckets = self._buckets_from_colors(colors)
        cells = [sorted(buckets[color], key=str) for color in sorted(buckets)]
        orbits = self._orbits_from_buckets(buckets)
        color_hist = {c: len(nodes) for c, nodes in buckets.items()}
        canonical_order = self._canonical_order_from_colors(self.G, colors)

        approx_count: Optional[int] = None
        if self.estimate_automorphisms:
            approx_count = approx_automorphism_count_from_cells(
                cells,
                cap=self.automorphism_cap,
            )

        self._state_cache = _WLState(
            colors=colors,
            cells=cells,
            orbits=orbits,
            color_hist=color_hist,
            iters_run=iters_run,
            stabilized=stabilized,
            canonical_order=canonical_order,
            approx_automorphism_count=approx_count,
        )
        # Anything derived from the previous state is now stale.
        self._summary_cache = None
        self._fast_signature = None
        self._cache_key_last = key
        return self._state_cache

    def colors(self) -> Dict[Any, str]:
        """
        Return final WL colors.

        :returns:
            Mapping from node to final color (a fresh copy).
        :rtype: Dict[Any, str]

        Example
        -------
        .. code-block:: python

            wl = WLCanonicalizer(syn)
            print(wl.colors())
        """
        return dict(self._run().colors)

    def color_of(self, v: Any) -> str:
        """
        Return the final WL color of one node.

        :param v:
            Node identifier.
        :type v: Any
        :returns:
            Final WL color.
        :rtype: str
        :raises KeyError:
            If ``v`` has no entry in the final color mapping.
        """
        return self._run().colors[v]

    def orbits(self) -> List[Set[Any]]:
        """
        Return approximate WL orbit sets.

        :returns:
            Approximate orbits induced by final WL colors (fresh copies).
        :rtype: List[Set[Any]]

        Example
        -------
        .. code-block:: python

            wl = WLCanonicalizer(syn)
            print(wl.orbits())
        """
        return [set(x) for x in self._run().orbits]

    def wl_orbits(self) -> List[Set[Any]]:
        """
        Alias for :meth:`orbits`.

        :returns:
            Approximate WL orbit sets.
        :rtype: List[Set[Any]]
        """
        return self.orbits()

    def has_nontrivial_automorphism(self) -> bool:
        """
        Heuristically detect whether symmetry may be present.

        This is approximate and simply checks whether any WL color cell has
        size greater than one.

        :returns:
            ``True`` if WL detects a non-singleton cell, else ``False``.
        :rtype: bool

        Example
        -------
        .. code-block:: python

            wl = WLCanonicalizer(syn)
            print(wl.has_nontrivial_automorphism())
        """
        return any(len(cell) > 1 for cell in self._run().cells)

    def canonical_order(self) -> List[Any]:
        """
        Return the deterministic WL node order.

        :returns:
            WL-based canonical node order (a fresh copy).
        :rtype: List[Any]
        """
        return list(self._run().canonical_order)

    def canonical_key(self) -> Any:
        """
        Return the canonical key induced by the WL order.

        :returns:
            WL canonical key.
        :rtype: Any

        Example
        -------
        .. code-block:: python

            wl = WLCanonicalizer(syn)
            print(wl.canonical_key())
        """
        return graph_key_from_order(self.G, self.canonical_order(), self.config)

    def canonical_graph(self) -> nx.DiGraph:
        """
        Return the canonically relabeled graph using the WL order.

        Nodes are renamed to ``1..n`` following :meth:`canonical_order`; the
        prepared graph itself is left untouched.

        :returns:
            WL-canonically relabeled graph.
        :rtype: nx.DiGraph

        Example
        -------
        .. code-block:: python

            wl = WLCanonicalizer(syn)
            G_can = wl.canonical_graph()
            print(sorted(G_can.nodes()))
        """
        mapping = {v: i for i, v in enumerate(self.canonical_order(), start=1)}
        return nx.relabel_nodes(self.G, mapping, copy=True)

    def graph(self) -> nx.DiGraph:
        """
        Alias for :meth:`canonical_graph`, matching the older canon style.

        :returns:
            WL-canonically relabeled graph.
        :rtype: nx.DiGraph
        """
        return self.canonical_graph()

    def canonical_result(self) -> WLCanonicalResult:
        """
        Build an approximate canonicalization result in a CRN-canon-like
        format.

        The result object is memoized. ``elapsed_seconds`` reports the time
        spent when the object was first built.

        :returns:
            Approximate canonicalization result.
        :rtype: WLCanonicalResult

        Example
        -------
        .. code-block:: python

            wl = WLCanonicalizer(syn)
            result = wl.canonical_result()
            print(result.canonical_order)
            print(result.automorphism_count)
        """
        start = perf_counter()
        # ``_run`` must be consulted before trusting the memo: it clears
        # ``_summary_cache`` when the cache key has changed.
        state = self._run()
        if self._summary_cache is not None:
            return self._summary_cache

        can_graph = self.canonical_graph()
        can_key = graph_key_from_order(self.G, state.canonical_order, self.config)
        self._summary_cache = WLCanonicalResult(
            canon_graph=can_graph,
            graph_type=self.graph_type,
            canonical_order=list(state.canonical_order),
            canonical_key=can_key,
            automorphism_count=state.approx_automorphism_count,
            orbits=[set(x) for x in state.orbits],
            colors=dict(state.colors),
            color_hist=dict(state.color_hist),
            iters_run=state.iters_run,
            stabilized=state.stabilized,
            exact=False,
            elapsed_seconds=perf_counter() - start,
        )
        return self._summary_cache

    def summary(self) -> Dict[str, Any]:
        """
        Return a dictionary summary in a format close to the exact
        canonicalizer.

        The reported automorphism count and orbit sets are WL-based
        approximations. The canonical node order is exposed under the key
        ``"canonical_perm"``.

        :returns:
            Summary dictionary.
        :rtype: Dict[str, Any]

        Example
        -------
        .. code-block:: python

            wl = WLCanonicalizer(syn)
            info = wl.summary()
            print(info["automorphism_count"])
            print(info["orbits"])
        """
        res = self.canonical_result()
        return {
            "canon_graph": res.canon_graph,
            "graph_type": res.graph_type,
            "automorphism_count": res.automorphism_count,
            "orbits": res.orbits,
            "canonical_perm": res.canonical_order,
            "canonical_key": res.canonical_key,
            "colors": res.colors,
            "color_hist": res.color_hist,
            "iters_run": res.iters_run,
            "stabilized": res.stabilized,
            "exact": res.exact,
            "elapsed_seconds": res.elapsed_seconds,
        }

    def fast_signature(self) -> Tuple[Any, ...]:
        """
        Return a fast graph signature using graph statistics and WL color
        histogram.

        This is useful as a cheap prefilter before exact graph isomorphism or
        exact canonicalization.

        :returns:
            Fast graph signature.
        :rtype: Tuple[Any, ...]

        Example
        -------
        .. code-block:: python

            wl = WLCanonicalizer(syn)
            print(wl.fast_signature())
        """
        # ``_run`` first, so a changed cache key invalidates the memo.
        state = self._run()
        if self._fast_signature is None:
            self._fast_signature = build_fast_signature(
                self.G,
                self.graph_type,
                self.config,
                wl_color_hist=state.color_hist,
            )
        return self._fast_signature
def wl_canonical(source: Any, **kwargs: Any) -> nx.DiGraph:
    """
    One-shot helper: canonicalize ``source`` with WL and return the graph.

    A :class:`WLCanonicalizer` is built from ``source`` and ``kwargs`` and its
    :meth:`WLCanonicalizer.canonical_graph` result is returned.

    :param source:
        Input CRN representation.
    :type source: Any
    :param kwargs:
        Keyword arguments passed through unchanged to
        :class:`WLCanonicalizer`.
    :type kwargs: Any
    :returns:
        WL-canonically relabeled graph.
    :rtype: nx.DiGraph

    Example
    -------
    .. code-block:: python

        from synkit.CRN.Sym import SymmetryConfig, wl_canonical

        G_can = wl_canonical(
            syn.to_digraph(),
            include_rule=True,
            config=SymmetryConfig.topological(),
        )
    """
    canonicalizer = WLCanonicalizer(source, **kwargs)
    relabeled = canonicalizer.canonical_graph()
    return relabeled