from typing import Optional
import networkx as nx
import hashlib
from synkit.IO import setup_logging
logger = setup_logging()
[docs]
class NautyCanonicalizer:
"""Perform Nauty-style canonicalization of a NetworkX graph, optionally
refining and distinguishing nodes and edges by specified attributes, and
extracting automorphisms, orbits, and canonical permutations.
:param node_attrs: List of node attribute keys to include in the
initial partition refinement. Nodes sharing the same tuple of
values under these keys will start in the same cell.
:type node_attrs: list[str] | None
:param edge_attrs: List of edge attribute keys to include when
distinguishing edges in the canonical label. If an edge has
none of these keys, its contribution will be empty.
:type edge_attrs: list[str] | None
"""
__slots__ = ("node_attrs", "edge_attrs")
def __init__(
self,
node_attrs: Optional[list[str]] = None,
edge_attrs: Optional[list[str]] = None,
) -> None:
"""Initialize the NautyCanonicalizer.
:param node_attrs: Node attribute names to use for initial
partitioning.
:type node_attrs: list[str] | None
:param edge_attrs: Edge attribute names to include in the
canonical label.
:type edge_attrs: list[str] | None
"""
self.node_attrs = list(node_attrs) if node_attrs else []
self.edge_attrs = list(edge_attrs) if edge_attrs else []
@staticmethod
def _freeze(x):
if isinstance(x, list):
return tuple(NautyCanonicalizer._freeze(v) for v in x)
if isinstance(x, dict):
return frozenset(
(k, NautyCanonicalizer._freeze(v)) for k, v in sorted(x.items())
)
return x
def _update_atom_map(self, G):
for n in G.nodes():
G.nodes[n]["atom_map"] = n
def _initial_partition(self, G):
if not self.node_attrs:
return [sorted(G.nodes())]
buckets = {}
for v in G.nodes():
key = tuple(
self._freeze(G.nodes[v].get(attr, None)) for attr in self.node_attrs
)
buckets.setdefault(key, []).append(v)
return [sorted(nodes) for _, nodes in sorted(buckets.items())]
def _node_signature(self, G, v, partition):
node_attrs = tuple(
self._freeze(G.nodes[v].get(a, None)) for a in self.node_attrs
)
degree = G.degree[v]
nbr_part_counts = []
for cell in partition:
count = sum(1 for nbr in G.neighbors(v) if nbr in cell)
nbr_part_counts.append(count)
nbr_part_counts = tuple(nbr_part_counts)
edge_attr_multiset = []
for nbr in G.neighbors(v):
attrs = G[v][nbr]
edge_attrs = []
for a in self.edge_attrs:
val = attrs.get(a, None)
if a == "order" and isinstance(val, tuple):
val = tuple(sorted(round(float(x), 3) for x in val))
edge_attrs.append(self._freeze(val))
edge_attr_multiset.append(tuple(edge_attrs))
edge_attr_multiset = tuple(sorted(edge_attr_multiset))
return (node_attrs, degree, nbr_part_counts, edge_attr_multiset)
def _refine(self, G, partition):
changed = True
while changed:
changed = False
new_partition = []
sig_cache = {}
for cell in partition:
if len(cell) <= 1:
new_partition.append(cell)
continue
sigs = {}
for v in cell:
if v not in sig_cache:
sig_cache[v] = self._node_signature(G, v, partition)
sig = sig_cache[v]
sigs.setdefault(sig, []).append(v)
if len(sigs) > 1:
changed = True
for sig in sorted(sigs):
new_partition.append(sorted(sigs[sig]))
else:
new_partition.append(cell)
partition = new_partition
return partition
def _search(self, G, partition, prefix, best, aut_perms, depth=0, max_depth=None):
    """Recursively refine and branch on ``partition``, keeping the smallest label.

    The partition is refined first. If every cell is then a singleton, a
    node ordering is assembled, labelled with ``_build_label`` and
    compared with ``best["label"]``: a strictly smaller label replaces
    the best and resets ``aut_perms``; an equal label is appended to
    ``aut_perms``. Otherwise the first cell holding more than one node is
    branched on: each of its nodes in turn is split off into its own
    cell and the search recurses.

    :param G: Graph being canonicalized.
    :param partition: Current list of cells (lists of nodes).
    :param prefix: Nodes split off so far along this branch.
    :param best: Dict with keys ``"label"`` and ``"perm"``; updated in
        place.
    :param aut_perms: List updated in place with every ordering whose
        label equals the current best label.
    :param depth: Current recursion depth.
    :param max_depth: Optional depth limit; ``None`` disables it.
    :returns: ``True`` if the depth limit was exceeded somewhere in this
        subtree (the search is then abandoned), otherwise ``False``.
    """
    if max_depth is not None and depth > max_depth:
        logger.debug(
            f"Early stopping at depth {depth} due to max_depth={max_depth}"
        )
        return True  # early stop triggered
    partition = self._refine(G, partition)
    if all(len(c) == 1 for c in partition):
        # NOTE(review): the partition still contains the nodes from
        # ``prefix`` as singleton cells, so every prefix node appears
        # twice in ``perm`` - confirm that consumers of best["perm"] and
        # aut_perms expect this.
        perm = prefix + [v for c in partition for v in c]
        label = self._build_label(G, perm)
        if best["label"] is None or label < best["label"]:
            best["label"], best["perm"] = label, perm
            # Orderings collected for the previous best no longer apply.
            aut_perms.clear()
            aut_perms.append(perm)
            logger.debug(f"New best label found at depth {depth}")
        elif label == best["label"]:
            aut_perms.append(perm)
            logger.debug(f"Equivalent label found at depth {depth}")
        return False
    # Branch on the first cell that is not yet a singleton; one exists
    # because the all-singleton case returned above.
    idx = next(i for i, c in enumerate(partition) if len(c) > 1)
    cell = partition[idx]
    # Candidates are tried in order of their "atom_map" attribute,
    # falling back to the node itself when the attribute is absent.
    sorted_cell = sorted(cell, key=lambda n: G.nodes[n].get("atom_map", n))
    for v in sorted_cell:
        rest = [w for w in cell if w != v]
        # Replace the chosen cell by [v] followed by the remaining nodes.
        # fmt: off
        new_partition = (
            partition[:idx]
            + [[v]]
            + ([sorted(rest)] if rest else [])
            + partition[idx + 1:]
        )
        # fmt: on
        candidate_prefix = prefix + [v]
        partial_label = self._build_partial_label(G, candidate_prefix)
        # Skip the branch when its partial label already compares greater
        # than the best complete label found so far.
        if best["label"] is not None and partial_label > best["label"]:
            logger.debug(f"Pruning branch at depth {depth} due to partial label")
            continue  # prune branch early
        if self._search(
            G,
            new_partition,
            candidate_prefix,
            best,
            aut_perms,
            depth=depth + 1,
            max_depth=max_depth,
        ):
            return True  # propagate early stop upward
    return False
def _build_label(self, G, perm):
node_segment = "|".join(
":".join(
str(self._freeze(G.nodes[v].get(attr, ""))) for attr in self.node_attrs
)
for v in perm
)
n = len(perm)
edge_bits = []
for i in range(n):
vi = perm[i]
for j in range(i + 1, n):
vj = perm[j]
if G.has_edge(vi, vj):
attrs = G[vi][vj]
frozen_attrs = tuple(
self._freeze(attrs.get(a, "")) for a in self.edge_attrs
)
edge_bits.append("1:" + ":".join(str(x) for x in frozen_attrs))
else:
edge_bits.append("0:" + ":".join("" for _ in self.edge_attrs))
edge_segment = "|".join(edge_bits)
return node_segment + "||" + edge_segment
def _build_partial_label(self, G, prefix):
node_segment = "|".join(
":".join(
str(self._freeze(G.nodes[v].get(attr, ""))) for attr in self.node_attrs
)
for v in prefix
)
suffix = "{" * 1000 # lexicographically larger than any label char
return node_segment + suffix
[docs]
def compute_orbits(self, aut_perms):
if not aut_perms:
return []
orbit_map = {}
orbits = []
def union_orbits(i, j):
if i == j:
return
o1 = orbits[i]
o2 = orbits[j]
if len(o1) < len(o2):
i, j = j, i
o1, o2 = o2, o1
o1.update(o2)
orbits[j] = set()
for v in o2:
orbit_map[v] = i
first_perm = aut_perms[0]
for idx, node in enumerate(first_perm):
orbit_map[node] = idx
orbits.append({node})
for perm in aut_perms:
for idx, node in enumerate(perm):
union_orbits(idx, orbit_map[node])
return [o for o in orbits if o]
[docs]
def graph_signature(self, G):
    """Return a SHA-256 hex digest identifying the canonical form of *G*.

    The graph is passed through ``canonical_form``; the result is
    labelled with its nodes in sorted order and the UTF-8 encoded label
    is hashed.

    :param G: Graph to fingerprint.
    :returns: Hexadecimal SHA-256 digest of the label.
    :rtype: str
    """
    canon = self.canonical_form(G)
    ordering = sorted(canon.nodes())
    digest = hashlib.sha256()
    digest.update(self._build_label(canon, ordering).encode("utf-8"))
    return digest.hexdigest()