# Source code for synkit.Chem.Reaction.Mapper.wl_mapper

from __future__ import annotations

from dataclasses import dataclass, field, asdict
from typing import (
    Any,
    Dict,
    Hashable,
    Iterable,
    Iterator,
    List,
    Optional,
    Tuple,
    Union,
    Set,
)
import hashlib
import itertools
import logging
import math
import time

from synkit.IO import rsmi_to_graph, graph_to_smi

try:
    from synkit.Chem.Reaction.canon_rsmi import CanonRSMI  # type: ignore
except Exception:
    CanonRSMI = None  # type: ignore

try:
    from synkit.Graph.ITS.its_construction import ITSConstruction  # type: ignore
    from synkit.Graph.ITS.its_decompose import get_rc  # type: ignore
except Exception:
    ITSConstruction = None  # type: ignore
    get_rc = None  # type: ignore


# Accepted types for ``WLMapperConfig.node_label_keys``: a single attribute
# name, or a list/tuple of names. ``WLMapperConfig._normalize_node_keys``
# converts any of these forms to a tuple.
_NodeLabelKeys = Union[Tuple[str, ...], List[str], str]


def _safe_int(x: Any) -> int:
    try:
        return int(x)
    except Exception:
        return 0


def _norm_pair(u: Hashable, v: Hashable) -> Tuple[Hashable, Hashable]:
    return (u, v) if repr(u) <= repr(v) else (v, u)


def _stable_sort_key(x: Hashable) -> str:
    return f"{type(x).__name__}:{repr(x)}"


def _u64_from_bytes(b: bytes) -> int:
    return int.from_bytes(b[:8], "little", signed=False) if b else 0


def _u64_to_bytes(x: int) -> bytes:
    return int(x & ((1 << 64) - 1)).to_bytes(8, "little", signed=False)


def _token_bytes(tok: Any) -> bytes:
    if tok is None:
        return b""
    if isinstance(tok, bytes):
        return tok
    if isinstance(tok, str):
        return tok.encode("utf-8", errors="ignore")
    if isinstance(tok, bool):
        return b"1" if tok else b"0"
    if isinstance(tok, int):
        return f"i:{tok}".encode("utf-8")
    if isinstance(tok, float):
        return f"f:{tok:.6g}".encode("utf-8")
    return repr(tok).encode("utf-8", errors="ignore")


def _is_double_token(tok: Any) -> bool:
    if tok is None:
        return False
    if isinstance(tok, (int, float)):
        return abs(float(tok) - 2.0) < 1e-6
    if isinstance(tok, str):
        s = tok.strip().lower()
        if s in {"2", "double", "d"}:
            return True
        try:
            return abs(float(s) - 2.0) < 1e-6
        except Exception:
            return False
    return False


def _is_single_token(tok: Any) -> bool:
    if tok is None:
        return False
    if isinstance(tok, (int, float)):
        return abs(float(tok) - 1.0) < 1e-6
    if isinstance(tok, str):
        s = tok.strip().lower()
        if s in {"1", "single", "s"}:
            return True
        try:
            return abs(float(s) - 1.0) < 1e-6
        except Exception:
            return False
    return False


def _multiset_l1_sorted(a: List[int], b: List[int]) -> int:
    i = 0
    j = 0
    da = len(a)
    db = len(b)
    miss = 0
    while i < da and j < db:
        va = a[i]
        vb = b[j]
        if va == vb:
            ca = 1
            cb = 1
            i += 1
            j += 1
            while i < da and a[i] == va:
                ca += 1
                i += 1
            while j < db and b[j] == vb:
                cb += 1
                j += 1
            miss += abs(ca - cb)
            continue
        if va < vb:
            ca = 1
            i += 1
            while i < da and a[i] == va:
                ca += 1
                i += 1
            miss += ca
            continue
        cb = 1
        j += 1
        while j < db and b[j] == vb:
            cb += 1
            j += 1
        miss += cb
    while i < da:
        va = a[i]
        ca = 1
        i += 1
        while i < da and a[i] == va:
            ca += 1
            i += 1
        miss += ca
    while j < db:
        vb = b[j]
        cb = 1
        j += 1
        while j < db and b[j] == vb:
            cb += 1
            j += 1
        miss += cb
    return int(miss)


@dataclass(frozen=True)
class WLMapperConfig:
    """Immutable configuration for the WL-hash based mapper.

    Call :meth:`validated` to obtain a checked instance whose
    ``node_label_keys`` is normalized to a tuple.
    """

    # WL hashing
    iterations: int = 4
    # blake2b digest size (bytes) of each WL node hash; must be in [8, 64]
    digest_size: int = 16
    # keep the initial-label digest as the first entry of each WL history
    include_initial: bool = True
    # edge attribute read as the bond order
    edge_attr: str = "order"
    # node attributes concatenated into the initial WL label
    node_label_keys: _NodeLabelKeys = ("element",)
    progressive_fallback: bool = True
    # label bonds between two aromatic atoms as "aromatic"
    normalize_aromatic_bonds: bool = True

    # Candidate exploration (edge masks)
    enable_bond_cut: bool = True
    max_bond_cut_size: int = 2
    max_candidates: int = 200
    candidate_edge_pool: int = 16
    # wall-clock budget (seconds) for candidate exploration; None or 0 = unlimited
    time_limit_s: Optional[float] = 2.0

    # Allow deeper masks (still PMCD-first; only used to *generate* candidates)
    enable_heuristic_scoring: bool = True
    heuristic_max_cut_size: int = 4
    heuristic_candidate_budget: int = 120

    # PMCD objective: strict minimization key (lexicographic)
    pmcd_unmapped_weight: int = 1
    pmcd_bond_weight: int = 1
    pmcd_hcount_weight: int = 1

    # Heuristic tie-break objective (ONLY after PMCD)
    heuristic_carbonyl_double_penalty: float = 50.0
    heuristic_aromatic_c_break_penalty: float = 0.75
    bc_cost_acyl_co: float = 0.10
    bc_cost_x_deg3_o: float = 0.35
    bc_cost_x_deg2_o: float = 0.55
    bc_cost_x_deg1_o: float = 0.75
    bc_cost_aromatic_co: float = 1.50
    bc_cost_other: float = 1.00
    bc_cost_order_mismatch_scale: float = 0.60
    # Peracid/peroxyacyl: prefer O-O cleavage; avoid acyl C-O cleavage in C(=O)-O-O
    bc_cost_peroxy_oo: float = 0.05
    bc_cost_acyl_co_peroxy: float = 3.00

    # Reaction-center restriction for PMCD stats (optional)
    rc_only_bond_changes: bool = True
    rc_only_hcount_changes: bool = True
    rc_expand_hops: int = 1

    # Refinements
    enable_rc_refine: bool = True
    rc_distance_weight: float = 0.5
    enable_symmetry_pruning: bool = True
    # number of leading WL rounds that define a symmetry class
    symmetry_depth: int = 4
    large_bucket_threshold: int = 25
    greedy_topk_per_u: int = 10
    hungarian_max_size: int = 10
    enable_dynamic_wl: bool = True
    enable_swap_refine: bool = True
    swap_refine_max_iter: int = 10
    swap_refine_class_depth: int = 4
    swap_refine_max_group_size: int = 14

    # Output
    multi_solutions: bool = True
    max_solutions: int = 6  # returned solutions (PMCD-min set will be trimmed by heuristic ordering)
    solution_score_slack: float = 0.0  # not used in PMCD-first mode (kept for compatibility)
    start_atom_map: int = 1
    unmapped_value: int = 0
    assign_maps_to_unmapped: bool = True
    use_its_final: bool = True
    drop_non_aam: bool = False
    use_index_as_atom_map: bool = False

    def validated(self) -> "WLMapperConfig":
        """Validate the configuration and normalize ``node_label_keys``.

        Returns ``self`` when no normalization is needed, otherwise a new
        instance of the same class with ``node_label_keys`` as a tuple.

        Raises:
            ValueError: if any setting is out of its allowed range.
        """
        keys = self._normalize_node_keys(self.node_label_keys)
        self._validate(keys)
        if keys == self.node_label_keys:
            return self
        data = asdict(self)
        data["node_label_keys"] = keys
        # type(self) keeps subclasses intact instead of downcasting.
        return type(self)(**data)

    @staticmethod
    def _normalize_node_keys(keys: _NodeLabelKeys) -> Tuple[str, ...]:
        """Convert a str/list/tuple of attribute names to a tuple."""
        if isinstance(keys, str):
            return (keys,)
        if isinstance(keys, list):
            return tuple(keys)
        if isinstance(keys, tuple):
            return keys
        raise ValueError("node_label_keys must be str/list/tuple")

    def _validate(self, keys: Tuple[str, ...]) -> None:
        """Raise ``ValueError`` for the first out-of-range setting found."""
        if not keys:
            raise ValueError("node_label_keys must be non-empty")
        if self.iterations < 1:
            raise ValueError("iterations must be >= 1")
        if self.digest_size < 8:
            raise ValueError("digest_size must be >= 8")
        if self.digest_size > 64:
            # hashlib.blake2b rejects digest sizes above 64 bytes; fail here
            # instead of deep inside WL hashing at fit time.
            raise ValueError("digest_size must be <= 64")
        if self.max_candidates < 1:
            raise ValueError("max_candidates must be >= 1")
        if self.candidate_edge_pool < 1:
            raise ValueError("candidate_edge_pool must be >= 1")
        if self.max_bond_cut_size < 0:
            raise ValueError("max_bond_cut_size must be >= 0")
        if self.time_limit_s is not None and self.time_limit_s < 0:
            raise ValueError("time_limit_s must be >= 0/None")
        if self.symmetry_depth < 1:
            raise ValueError("symmetry_depth must be >= 1")
        if self.large_bucket_threshold < 2:
            raise ValueError("large_bucket_threshold must be >= 2")
        if self.start_atom_map < 1:
            raise ValueError("start_atom_map must be >= 1")
        if self.hungarian_max_size < 2:
            raise ValueError("hungarian_max_size must be >= 2")
        if self.greedy_topk_per_u < 1:
            raise ValueError("greedy_topk_per_u must be >= 1")
        weights = (
            # backward-compat: only present if a subclass still defines it
            getattr(self, "heuristic_extra_penalty", 0.0),
            self.heuristic_aromatic_c_break_penalty,
            self.heuristic_carbonyl_double_penalty,
            self.bc_cost_acyl_co,
            self.bc_cost_x_deg3_o,
            self.bc_cost_x_deg2_o,
            self.bc_cost_x_deg1_o,
            self.bc_cost_aromatic_co,
            self.bc_cost_other,
            self.bc_cost_order_mismatch_scale,
            self.bc_cost_peroxy_oo,
            self.bc_cost_acyl_co_peroxy,
        )
        if any(w < 0 for w in weights):
            raise ValueError("weights must be >= 0")

    def as_dict(self) -> Dict[str, Any]:
        """Return the configuration as a plain (deep-copied) dict."""
        return asdict(self)
@dataclass
class MappingResult:
    """A candidate atom mapping with its numeric score and free-form metadata.

    ``score`` is a numeric proxy of the PMCD objective kept for convenience;
    the actual minimization compares the tuple stored in
    ``meta["pmcd_key"]``.
    """

    # node -> node atom mapping (presumably reactant -> product; TODO confirm)
    mapping: Dict[Hashable, Hashable]
    score: float
    meta: Dict[str, Any] = field(default_factory=dict)
@dataclass
class Solution:
    """A mapping result together with its atom-mapped reaction SMILES string."""

    result: MappingResult
    mapped_rsmi: str
@dataclass(frozen=True)
class EdgeMaskCandidate:
    """An immutable edge-mask candidate: bonds hidden on one reaction side.

    ``meta`` is excluded from hashing (dicts are unhashable, which made the
    generated ``__hash__`` raise ``TypeError``) but still takes part in
    equality, so equal candidates hash equally and can be used in sets or
    as dict keys.
    """

    # which side the mask applies to (exact values set by the producer; TODO confirm)
    side: str
    # node pairs whose bonds are hidden
    removed_pairs: frozenset
    # kept for API-compat; NOT used in PMCD stage
    prior_score: float = 0.0
    meta: Dict[str, Any] = field(default_factory=dict, hash=False)
class _CanonAdapter: def canonicalise(self, rsmi: str) -> str: if CanonRSMI is None: return rsmi try: return CanonRSMI().canonicalise(rsmi).canonical_rsmi except Exception: return rsmi class _SymmetryPruner: def __init__(self, cfg: WLMapperConfig) -> None: self._cfg = cfg def node_classes( self, hmap: Dict[Hashable, List[bytes]] ) -> Dict[Hashable, Tuple[bytes, ...]]: depth = int(self._cfg.symmetry_depth) out: Dict[Hashable, Tuple[bytes, ...]] = {} for u, seq in hmap.items(): out[u] = tuple(seq[: min(depth, len(seq))]) return out def unique_edges( self, edges: List[Tuple[Hashable, Hashable]], nclass: Dict[Hashable, Tuple[bytes, ...]], edge_key_fn, ) -> List[Tuple[Hashable, Hashable]]: seen: set = set() out: List[Tuple[Hashable, Hashable]] = [] for u, v in edges: cu = nclass.get(u) cv = nclass.get(v) if cu is None or cv is None: out.append((u, v)) continue a, b = (cu, cv) if repr(cu) <= repr(cv) else (cv, cu) sig = (a, b, edge_key_fn(u, v)) if sig in seen: continue seen.add(sig) out.append((u, v)) return out
class GraphCache:
    """Precomputed lookups over one side of a reaction graph.

    ``G`` is used through a networkx-style API (``nodes``,
    ``nodes(data=True)``, ``neighbors``, ``edges``, ``get_edge_data``);
    TODO confirm the concrete graph type produced by ``rsmi_to_graph``.

    Per node it caches the attribute dict, neighbours (tuple and set),
    element, aromatic flag, ``hcount``/``charge`` (via ``_safe_int``) and a
    sorted list of 32-bit neighbour signature codes. Per edge, keyed by
    ``_norm_pair``, it caches the bond-order tokens, the first token, that
    token's stable 64-bit code and its single/double flags. It also records
    carbonyl carbons, acyl oxygens and peroxy-acyl oxygens used by the
    chemical heuristics.
    """

    def __init__(self, G: Any, cfg: WLMapperConfig) -> None:
        self.G = G
        self.cfg = cfg
        self.nodes: List[Hashable] = list(G.nodes)
        self.node_set: Set[Hashable] = set(self.nodes)
        self._node_data: Dict[Hashable, Dict[str, Any]] = dict(G.nodes(data=True))

        self._neighbors: Dict[Hashable, Tuple[Hashable, ...]] = {}
        self._neighbor_set: Dict[Hashable, Set[Hashable]] = {}
        for node in self.nodes:
            adjacent = tuple(G.neighbors(node))
            self._neighbors[node] = adjacent
            self._neighbor_set[node] = set(adjacent)

        self._element: Dict[Hashable, Any] = {}
        self._aromatic: Dict[Hashable, bool] = {}
        self._hcount: Dict[Hashable, int] = {}
        self._charge: Dict[Hashable, int] = {}
        for node in self.nodes:
            attrs = self._node_data[node]
            self._element[node] = attrs.get("element")
            self._aromatic[node] = bool(attrs.get("aromatic"))
            self._hcount[node] = _safe_int(attrs.get("hcount", 0))
            self._charge[node] = _safe_int(attrs.get("charge", 0))

        # Edge caches keyed by _norm_pair(u, v); filled by _build_edge_caches.
        self._edge_orders: Dict[Tuple[Hashable, Hashable], Tuple[Any, ...]] = {}
        self._edge_label_one: Dict[Tuple[Hashable, Hashable], Any] = {}
        self._edge_code_u64: Dict[Tuple[Hashable, Hashable], int] = {}
        self._edge_is_single: Dict[Tuple[Hashable, Hashable], bool] = {}
        self._edge_is_double: Dict[Tuple[Hashable, Hashable], bool] = {}
        self._neigh_codes: Dict[Hashable, List[int]] = {}
        self._carbonyl_c: Set[Hashable] = set()
        self._acyl_oxygen: Set[Hashable] = set()
        self._acyl_oxygen_peroxy: Set[Hashable] = set()

        # Order matters: carbonyl detection reads edge flags, acyl oxygens
        # read carbonyls, neighbour codes read edge labels.
        self._build_edge_caches()
        self._build_carbonyl_cache()
        self._build_acyl_oxygen_cache()
        self._build_neighbor_codes()

    # ---------- node accessors ----------

    def node_data(self, u: Hashable) -> Dict[str, Any]:
        """Raw attribute dict of ``u`` (KeyError if unknown)."""
        return self._node_data[u]

    def neighbors(self, u: Hashable) -> Tuple[Hashable, ...]:
        """Neighbours of ``u`` in graph order; empty for unknown nodes."""
        return self._neighbors.get(u, ())

    def neighbors_set(self, u: Hashable) -> Set[Hashable]:
        """Neighbours of ``u`` as a set; empty for unknown nodes."""
        return self._neighbor_set.get(u, set())

    def element(self, u: Hashable) -> Any:
        return self._element.get(u)

    def aromatic(self, u: Hashable) -> bool:
        return bool(self._aromatic.get(u, False))

    def hcount(self, u: Hashable) -> int:
        return int(self._hcount.get(u, 0))

    def charge(self, u: Hashable) -> int:
        return int(self._charge.get(u, 0))

    def degree(self, u: Hashable) -> int:
        return int(len(self.neighbors(u)))

    # ---------- edge accessors (either node order) ----------

    def edge_orders(self, u: Hashable, v: Hashable) -> Tuple[Any, ...]:
        return self._edge_orders.get(_norm_pair(u, v), ())

    def edge_label_one(self, u: Hashable, v: Hashable) -> Any:
        return self._edge_label_one.get(_norm_pair(u, v))

    def edge_code_u64(self, u: Hashable, v: Hashable) -> int:
        return int(self._edge_code_u64.get(_norm_pair(u, v), 0))

    def edge_is_single(self, u: Hashable, v: Hashable) -> bool:
        return bool(self._edge_is_single.get(_norm_pair(u, v), False))

    def edge_is_double(self, u: Hashable, v: Hashable) -> bool:
        return bool(self._edge_is_double.get(_norm_pair(u, v), False))

    # ---------- derived features ----------

    def neigh_codes(self, u: Hashable) -> List[int]:
        """Sorted 32-bit (edge label, neighbour atom) codes for ``u``."""
        return self._neigh_codes.get(u, [])

    def is_carbonyl_c(self, u: Hashable) -> bool:
        return u in self._carbonyl_c

    def is_acyl_oxygen(self, o: Hashable) -> bool:
        return o in self._acyl_oxygen

    def is_acyl_oxygen_peroxy(self, o: Hashable) -> bool:
        return o in self._acyl_oxygen_peroxy

    # ---------- internal builds ----------

    def _build_edge_caches(self) -> None:
        """Fill every per-edge cache once per normalized node pair."""
        for raw_u, raw_v in self.G.edges():
            key = _norm_pair(raw_u, raw_v)
            if key in self._edge_orders:
                continue  # multigraphs yield the same pair repeatedly
            tokens = self._edge_orders_from_get_edge_data(key[0], key[1])
            first = tokens[0] if tokens else None
            self._edge_orders[key] = tokens
            self._edge_label_one[key] = first
            self._edge_code_u64[key] = self._stable_u64_for_token(first)
            self._edge_is_single[key] = _is_single_token(first)
            self._edge_is_double[key] = _is_double_token(first)

    def _edge_orders_from_get_edge_data(
        self, u: Hashable, v: Hashable
    ) -> Tuple[Any, ...]:
        """Bond-order tokens of edge ``u``-``v`` (sorted by repr for multi-edges)."""
        data = self.G.get_edge_data(u, v, default=None)
        if not isinstance(data, dict):
            return ()
        attr = self.cfg.edge_attr
        # Integer keys signal a multigraph key -> attribute-dict mapping.
        if not any(isinstance(k, int) for k in data.keys()):
            token = self._edge_order_token(u, v, data.get(attr))
            return () if token is None else (token,)
        candidates = [
            self._edge_order_token(u, v, d.get(attr))
            for d in data.values()
            if isinstance(d, dict)
        ]
        tokens = [t for t in candidates if t is not None]
        return tuple(sorted(tokens, key=repr)) if tokens else ()

    def _edge_order_token(self, u: Hashable, v: Hashable, o: Any) -> Any:
        """Normalize a raw bond-order value into a hashable token."""
        if self._is_aromatic_edge(u, v):
            return "aromatic"
        if o is None:
            return None
        if isinstance(o, str):
            return "aromatic" if "arom" in o.lower() else o
        if isinstance(o, (int, float)):
            value = float(o)
            if abs(value - 1.5) < 1e-3:
                return "aromatic"
            return round(value, 3) if isinstance(o, float) else o
        return o

    def _is_aromatic_edge(self, u: Hashable, v: Hashable) -> bool:
        """True when aromatic normalization is on and both atoms are aromatic."""
        return bool(self.cfg.normalize_aromatic_bonds) and self.aromatic(u) and self.aromatic(v)

    def _stable_u64_for_token(self, tok: Any) -> int:
        digest = hashlib.blake2b(b"edge:" + _token_bytes(tok), digest_size=8).digest()
        return _u64_from_bytes(digest)

    def _stable_u32_for_neighbor_key(self, tok: Any, v: Hashable) -> int:
        """32-bit code of (edge token, neighbour element/aromatic/charge/hcount)."""
        payload = b"".join(
            (
                b"n:",
                _token_bytes(tok),
                b"|",
                _token_bytes(self.element(v)),
                b"|",
                b"1" if self.aromatic(v) else b"0",
                b"|",
                str(self.charge(v)).encode("utf-8"),
                b"|",
                str(self.hcount(v)).encode("utf-8"),
            )
        )
        digest = hashlib.blake2b(payload, digest_size=8).digest()
        return int.from_bytes(digest[:4], "little", signed=False)

    def _build_neighbor_codes(self) -> None:
        self._neigh_codes = {
            u: sorted(
                self._stable_u32_for_neighbor_key(self.edge_label_one(u, v), v)
                for v in self.neighbors(u)
            )
            for u in self.nodes
        }

    def _build_carbonyl_cache(self) -> None:
        """Carbons carrying at least one double-bonded oxygen."""
        self._carbonyl_c = {
            u
            for u in self.nodes
            if self.element(u) == "C"
            and any(
                self.element(v) == "O" and self.edge_is_double(u, v)
                for v in self.neighbors(u)
            )
        }

    def _build_acyl_oxygen_cache(self) -> None:
        """Non-double-bonded oxygens on carbonyl carbons, and their peroxy subset."""
        acyl_o = {
            o
            for c in self._carbonyl_c
            for o in self.neighbors(c)
            if self.element(o) == "O" and not self.edge_is_double(c, o)
        }
        # peroxyacyl oxygen: acyl oxygen bonded to another oxygen (C(=O)-O-O...)
        acyl_o_peroxy = {
            o
            for o in acyl_o
            if any(
                self.element(n) == "O" and n != o and self.edge_orders(o, n)
                for n in self.neighbors(o)
            )
        }
        self._acyl_oxygen = acyl_o
        self._acyl_oxygen_peroxy = acyl_o_peroxy
class MaskView:
    """Read-only view of a graph cache with some bonds treated as absent.

    ``removed_pairs`` holds node pairs whose bond is hidden; a pair matches
    in either orientation. Nodes that touch a removed pair get precomputed
    neighbour / degree overrides (keeping the cache's adjacency order, so
    results never depend on set iteration order); every other lookup
    delegates to the underlying cache.
    """

    def __init__(self, cache: GraphCache, removed_pairs: Optional[frozenset]) -> None:
        self.cache = cache
        self.removed_pairs = removed_pairs
        self._over_neighbors: Dict[Hashable, Tuple[Hashable, ...]] = {}
        self._over_neigh_set: Dict[Hashable, Set[Hashable]] = {}
        self._over_deg: Dict[Hashable, int] = {}
        self._over_neigh_codes: Dict[Hashable, List[int]] = {}
        if removed_pairs:
            self._build_overrides(removed_pairs)

    def masked(self, u: Hashable, v: Hashable) -> bool:
        """True if the bond ``u``-``v`` is hidden (pair stored in either order)."""
        rp = self.removed_pairs
        if not rp:
            return False
        return (u, v) in rp or (v, u) in rp

    def neighbors(self, u: Hashable) -> Tuple[Hashable, ...]:
        """Visible neighbours of ``u`` in the cache's adjacency order."""
        if u in self._over_neighbors:
            return self._over_neighbors[u]
        return self.cache.neighbors(u)

    def neighbors_set(self, u: Hashable) -> Set[Hashable]:
        if u in self._over_neigh_set:
            return self._over_neigh_set[u]
        return self.cache.neighbors_set(u)

    def degree(self, u: Hashable) -> int:
        if u in self._over_deg:
            return self._over_deg[u]
        return self.cache.degree(u)

    def edge_orders(self, u: Hashable, v: Hashable) -> Tuple[Any, ...]:
        return () if self.masked(u, v) else self.cache.edge_orders(u, v)

    def edge_label_one(self, u: Hashable, v: Hashable) -> Any:
        return None if self.masked(u, v) else self.cache.edge_label_one(u, v)

    def edge_code_u64(self, u: Hashable, v: Hashable) -> int:
        return 0 if self.masked(u, v) else self.cache.edge_code_u64(u, v)

    def edge_is_single(self, u: Hashable, v: Hashable) -> bool:
        return False if self.masked(u, v) else self.cache.edge_is_single(u, v)

    def edge_is_double(self, u: Hashable, v: Hashable) -> bool:
        return False if self.masked(u, v) else self.cache.edge_is_double(u, v)

    def neigh_codes(self, u: Hashable) -> List[int]:
        """Sorted neighbour codes; recomputed (and memoized) for masked nodes."""
        if u not in self._over_neighbors:
            return self.cache.neigh_codes(u)
        codes = self._over_neigh_codes.get(u)
        if codes is None:
            codes = self._compute_neigh_codes_masked(u)
            self._over_neigh_codes[u] = codes
        return codes

    def _compute_neigh_codes_masked(self, u: Hashable) -> List[int]:
        codes = [
            self.cache._stable_u32_for_neighbor_key(self.edge_label_one(u, v), v)
            for v in self.neighbors(u)
        ]
        codes.sort()
        return codes

    def _build_overrides(self, removed_pairs: frozenset) -> None:
        """Precompute neighbours/degree for every endpoint of a removed pair."""
        dropped: Dict[Hashable, Set[Hashable]] = {}
        for a, b in removed_pairs:
            dropped.setdefault(a, set()).add(b)
            dropped.setdefault(b, set()).add(a)
        for u, gone in dropped.items():
            if u not in self.cache.node_set:
                continue
            # Filter the cache's tuple (not a set) to keep deterministic order.
            kept = tuple(v for v in self.cache.neighbors(u) if v not in gone)
            self._over_neighbors[u] = kept
            self._over_neigh_set[u] = set(kept)
            self._over_deg[u] = len(kept)
class _FastWLHasher:
    """Weisfeiler-Lehman style node hashing over a (possibly masked) view.

    Each round combines a node's previous hash, the XOR and the wrapped
    64-bit sum of ``edge_code ^ neighbour_previous_hash`` over its visible
    neighbours, and its degree into a blake2b digest of ``cfg.digest_size``
    bytes. XOR/sum aggregation makes the result independent of neighbour
    iteration order.
    """

    _MASK64 = (1 << 64) - 1

    def __init__(self, cfg: WLMapperConfig) -> None:
        self.cfg = cfg

    def full_hashes(
        self,
        view: MaskView,
        init_labels: Dict[Hashable, bytes],
        blake_fn,
    ) -> Dict[Hashable, List[bytes]]:
        """Hash every node from scratch.

        Returns node -> per-round digests; when ``cfg.include_initial`` is
        set the list starts with the initial-label digest.
        """
        nodes = view.cache.nodes
        prev: Dict[Hashable, bytes] = {
            u: blake_fn(b"init:" + init_labels.get(u, b"")) for u in nodes
        }
        if self.cfg.include_initial:
            out: Dict[Hashable, List[bytes]] = {u: [prev[u]] for u in nodes}
        else:
            out = {u: [] for u in nodes}
        for _ in range(int(self.cfg.iterations)):
            prev = self._one_round(view, prev, out)
        return out

    def masked_hashes(
        self,
        view: MaskView,
        init_labels: Dict[Hashable, bytes],
        base_wl: Dict[Hashable, List[bytes]],
        removed_pairs: frozenset,
        blake_fn,
    ) -> Tuple[Dict[Hashable, List[bytes]], Set[Hashable]]:
        """Incrementally rehash only nodes near the removed bonds.

        ``base_wl`` must be the unmasked ``full_hashes`` result for the same
        config. Returns ``(hashes, affected_nodes)``; histories of unaffected
        nodes are shared with ``base_wl``. The result equals
        ``full_hashes`` on the masked view.
        """
        if not removed_pairs:
            return base_wl, set()
        affected = self._affected_nodes(view.cache, removed_pairs)
        if not affected:
            return base_wl, set()

        out: Dict[Hashable, List[bytes]] = {
            u: (list(seq) if u in affected else seq) for u, seq in base_wl.items()
        }
        prev_aff: Dict[Hashable, bytes] = {
            u: blake_fn(b"init:" + init_labels.get(u, b"")) for u in affected
        }
        if self.cfg.include_initial:
            for u, hu in prev_aff.items():
                out[u][0] = hu
            init_prev: Dict[Hashable, bytes] = {}
        else:
            # Without an initial slot, round 1 of full_hashes reads the
            # initial-label digest of every neighbour. Unaffected neighbours
            # of the affected region have no stored entry for it, so supply
            # it here (previously b"" was used, which diverged from the full
            # recomputation).
            init_prev = {
                v: blake_fn(b"init:" + init_labels.get(v, b""))
                for u in affected
                for v in view.neighbors(u)
                if v not in affected
            }
        for r in range(1, int(self.cfg.iterations) + 1):
            prev_aff = self._one_round_partial(
                view, prev_aff, out, affected, r, init_prev=init_prev
            )
        return out, affected

    def _affected_nodes(
        self, cache: GraphCache, removed_pairs: frozenset
    ) -> Set[Hashable]:
        """Nodes within ``cfg.iterations`` hops of a removed-pair endpoint."""
        seeds: Set[Hashable] = set()
        for a, b in removed_pairs:
            if a in cache.node_set:
                seeds.add(a)
            if b in cache.node_set:
                seeds.add(b)
        if not seeds:
            return set()
        seen = set(seeds)
        frontier = set(seeds)
        for _ in range(int(self.cfg.iterations)):
            frontier = {
                v for u in frontier for v in cache.neighbors(u) if v not in seen
            }
            seen |= frontier
            if not frontier:
                break
        return seen

    def _one_round(
        self,
        view: MaskView,
        prev: Dict[Hashable, bytes],
        out: Dict[Hashable, List[bytes]],
    ) -> Dict[Hashable, bytes]:
        """Run one WL round over all nodes, appending to ``out``."""
        new_prev: Dict[Hashable, bytes] = {}
        for u in view.cache.nodes:
            hu = self._hash_node(view, u, prev)
            out[u].append(hu)
            new_prev[u] = hu
        return new_prev

    def _one_round_partial(
        self,
        view: MaskView,
        prev_aff: Dict[Hashable, bytes],
        out: Dict[Hashable, List[bytes]],
        affected: Set[Hashable],
        round_idx: int,
        init_prev: Optional[Dict[Hashable, bytes]] = None,
    ) -> Dict[Hashable, bytes]:
        """Run WL round ``round_idx`` for affected nodes only, writing in place."""
        new_prev: Dict[Hashable, bytes] = {}
        pos = round_idx if self.cfg.include_initial else (round_idx - 1)
        for u in affected:
            hu = self._hash_node_partial(
                view, u, prev_aff, out, affected, pos, init_prev=init_prev
            )
            out[u][pos] = hu
            new_prev[u] = hu
        return new_prev

    def _hash_node(
        self, view: MaskView, u: Hashable, prev: Dict[Hashable, bytes]
    ) -> bytes:
        """Hash ``u`` given a complete previous-round map."""
        return self._digest(view, u, prev.get(u, b""), lambda v: prev.get(v, b""))

    def _hash_node_partial(
        self,
        view: MaskView,
        u: Hashable,
        prev_aff: Dict[Hashable, bytes],
        out: Dict[Hashable, List[bytes]],
        affected: Set[Hashable],
        prev_pos: int,
        init_prev: Optional[Dict[Hashable, bytes]] = None,
    ) -> bytes:
        """Hash ``u`` for the round stored at index ``prev_pos``.

        Affected nodes read their previous hash from ``prev_aff``; others
        read it from their unchanged base history at ``prev_pos - 1``, or
        from ``init_prev`` when no earlier slot exists (round 1 without an
        initial entry).
        """
        fallback = init_prev if init_prev is not None else {}

        def stored_previous(node: Hashable) -> bytes:
            if prev_pos - 1 >= 0:
                return out[node][prev_pos - 1]
            return fallback.get(node, b"")

        def previous(node: Hashable) -> bytes:
            if node in affected:
                value = prev_aff.get(node)
                if value is not None:
                    return value
            return stored_previous(node)

        own = prev_aff.get(u)
        if own is None:
            own = stored_previous(u)
        return self._digest(view, u, own, previous)

    def _digest(self, view: MaskView, u: Hashable, own_prev: bytes, prev_of) -> bytes:
        """Combine ``u``'s previous hash with its neighbourhood into a digest."""
        mask = self._MASK64
        acc_xor = 0
        acc_sum = 0
        for v in view.neighbors(u):
            c = (int(view.edge_code_u64(u, v)) ^ _u64_from_bytes(prev_of(v))) & mask
            acc_xor ^= c
            acc_sum = (acc_sum + c) & mask
        h = hashlib.blake2b(digest_size=int(self.cfg.digest_size))
        h.update(b"wl:")
        h.update(_u64_to_bytes(_u64_from_bytes(own_prev)))
        h.update(_u64_to_bytes(acc_xor))
        h.update(_u64_to_bytes(acc_sum))
        h.update(_u64_to_bytes(view.degree(u)))
        return h.digest()
class WLMapper:
    """
    PMCD-first atom-atom mapper driven by Weisfeiler-Lehman (WL) node hashes.

    Pipeline:

    1) Enumerate candidate edge masks (bond cuts); for each candidate compute a
       mapping and its PMCD key ``(unmapped_atoms, bond_changes,
       hcount_changes)``.
    2) Keep the PMCD-minimal set (several solutions may tie).
    3) Apply the chemical heuristic ONLY to choose/order among the PMCD-minimal
       solutions; the heuristic never overrides the PMCD key.
    """

    def __init__(
        self,
        config: Optional[WLMapperConfig] = None,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        """
        Create an unfitted mapper.

        :param config: Mapper configuration. ``None`` (the default) builds a
            fresh ``WLMapperConfig()`` for this instance instead of sharing one
            default object created at class-definition time.
        :param logger: Logger for diagnostics; defaults to the module logger.
        """
        if config is None:
            config = WLMapperConfig()
        self.cfg = config.validated()
        self._log = logger or logging.getLogger(__name__)
        self._canon = _CanonAdapter()
        self._sym = _SymmetryPruner(self.cfg)
        self._hasher = _FastWLHasher(self.cfg)
        # All per-fit state is initialised in one place so that __init__ and
        # reset() cannot drift apart.
        self.reset()

    # -------- public API --------

    def reset(self) -> "WLMapper":
        """
        Clear all per-fit state (graphs, caches, WL hashes and results).

        :returns: ``self`` for chaining.
        """
        self._rG: Any = None
        self._pG: Any = None
        self._r_cache: Optional[GraphCache] = None
        self._p_cache: Optional[GraphCache] = None
        self._mapped_rsmi: Optional[str] = None
        self._mapping: Dict[Hashable, Hashable] = {}
        self._best: Optional[MappingResult] = None
        self._solutions: List[Solution] = []
        self._bond_change_report: List[Dict[str, Any]] = []
        # initial node labels (bytes) and per-node WL hash sequences
        self._r_init: Dict[Hashable, bytes] = {}
        self._p_init: Dict[Hashable, bytes] = {}
        self._r_wl: Dict[Hashable, List[bytes]] = {}
        self._p_wl: Dict[Hashable, List[bytes]] = {}
        # symmetry classes; only filled when symmetry pruning is enabled
        self._r_class: Dict[Hashable, Tuple[bytes, ...]] = {}
        self._p_class: Dict[Hashable, Tuple[bytes, ...]] = {}
        return self

    def fit(self, rsmi: str) -> "WLMapper":
        """
        Compute an atom-atom mapping for a reaction SMILES.

        Results are exposed through ``mapped_rsmi``, ``aam``, ``solutions``,
        ``bond_change_report`` and the ``best_*`` properties.

        :param rsmi: Reaction SMILES (``reactants>>products``).
        :returns: ``self`` (fitted).
        :raises ValueError: if ``rsmi_to_graph`` returns ``None`` for a side, or
            if no ``graph_to_smi`` variant can serialise a mapped graph.
        """
        self.reset()
        rG, pG = self._parse_rsmi(rsmi)
        self._rG, self._pG = rG, pG
        self._r_cache = GraphCache(rG, self.cfg)
        self._p_cache = GraphCache(pG, self.cfg)
        self._prime_baseline()

        best, pmcd_min_pool, _full_pool = self._solve_pmcd_then_heuristic()
        self._best = best
        self._mapping = dict(best.mapping)

        # expose solutions: PMCD-min pool, already ordered by heuristic cost
        self._solutions = self._materialize_solutions(pmcd_min_pool)
        self._mapped_rsmi = (
            self._solutions[0].mapped_rsmi
            if self._solutions
            else self._materialize_one(best)
        )
        # report for the chosen best mapping
        self._bond_change_report = self._bond_change_report_for_mapping(best.mapping)
        return self

    @property
    def mapped_rsmi(self) -> str:
        """Canonicalised atom-mapped reaction SMILES of the best solution."""
        if self._mapped_rsmi is None:
            raise RuntimeError("WLMapper not fitted")
        return self._mapped_rsmi

    @property
    def aam(self) -> Dict[Hashable, Hashable]:
        """Copy of the best mapping (reactant node -> product node)."""
        return dict(self._mapping)

    @property
    def best_score(self) -> float:
        """Numeric proxy of the best PMCD key (selection uses the tuple key)."""
        if self._best is None:
            raise RuntimeError("WLMapper not fitted")
        return float(self._best.score)

    @property
    def best_pmcd_key(self) -> Tuple[int, int, int]:
        """PMCD key ``(unmapped, bond_changes, hcount_changes)`` of the best."""
        if self._best is None:
            raise RuntimeError("WLMapper not fitted")
        return tuple(self._best.meta.get("pmcd_key", (10**9, 10**9, 10**9)))  # type: ignore[return-value]

    @property
    def best_heuristic_cost(self) -> float:
        """Heuristic tie-break cost recorded for the best solution."""
        if self._best is None:
            raise RuntimeError("WLMapper not fitted")
        return float(self._best.meta.get("heuristic_cost", 1e9))

    @property
    def best_meta(self) -> Dict[str, Any]:
        """Copy of the best solution's metadata (empty dict if unfitted)."""
        return {} if self._best is None else dict(self._best.meta)

    @property
    def solutions(self) -> List[Solution]:
        """PMCD-minimal solutions, deduplicated by mapped SMILES."""
        return list(self._solutions)

    @property
    def bond_change_report(self) -> List[Dict[str, Any]]:
        """Per-bond change records for the best mapping."""
        return list(self._bond_change_report)

    # -------- parsing + baseline --------

    def _parse_rsmi(self, rsmi: str) -> Tuple[Any, Any]:
        """Parse ``rsmi`` into (reactant graph, product graph)."""
        rG, pG = rsmi_to_graph(
            rsmi,
            drop_non_aam=self.cfg.drop_non_aam,
            use_index_as_atom_map=self.cfg.use_index_as_atom_map,
        )
        if rG is None or pG is None:
            raise ValueError("rsmi_to_graph returned None")
        return rG, pG

    def _prime_baseline(self) -> None:
        """Compute initial labels, unmasked WL hashes and symmetry classes."""
        if self._r_cache is None or self._p_cache is None:
            raise RuntimeError("cache missing")
        self._r_init = self._make_init_bytes(self._r_cache)
        self._p_init = self._make_init_bytes(self._p_cache)
        rv = MaskView(self._r_cache, None)
        pv = MaskView(self._p_cache, None)
        self._r_wl = self._hasher.full_hashes(rv, self._r_init, self._blake_bytes)
        self._p_wl = self._hasher.full_hashes(pv, self._p_init, self._blake_bytes)
        if self.cfg.enable_symmetry_pruning:
            self._r_class = self._sym.node_classes(self._r_wl)
            self._p_class = self._sym.node_classes(self._p_wl)

    def _make_init_bytes(self, cache: GraphCache) -> Dict[Hashable, bytes]:
        """Encode each node's configured label attributes as ``k=v|k=v`` bytes."""
        keys = WLMapperConfig._normalize_node_keys(self.cfg.node_label_keys)
        out: Dict[Hashable, bytes] = {}
        for u in cache.nodes:
            d = cache.node_data(u)
            s = "|".join(f"{k}={d.get(k, None)}" for k in keys)
            out[u] = s.encode("utf-8", errors="ignore")
        return out

    # -------- PMCD-first solve --------

    def _solve_pmcd_then_heuristic(
        self,
    ) -> Tuple[MappingResult, List[MappingResult], List[MappingResult]]:
        """
        Returns:
            best: chosen by (PMCD key) then heuristic cost
            pmcd_min_pool: all PMCD-min solutions (trimmed & ordered by heuristic)
            full_pool: all evaluated solutions (debug)
        """
        # monotonic clock: the time budget must not be affected by wall-clock jumps
        t0 = time.monotonic()
        full_pool: List[MappingResult] = []

        # baseline (no mask)
        base_map = self._map_two_pass(None)
        base_res = self._score_pmcd_and_heuristic(base_map, meta={"branch": "baseline"})
        full_pool.append(base_res)

        best_pmcd_key = base_res.meta["pmcd_key"]
        pmcd_min: List[MappingResult] = [base_res]

        explored = 0
        for cand in self._candidate_stream():
            if explored >= int(self.cfg.max_candidates):
                break
            if self._time_exceeded(t0):
                break
            explored += 1

            cand_map = self._map_two_pass(cand)
            res = self._score_pmcd_and_heuristic(cand_map, meta=dict(cand.meta))
            full_pool.append(res)

            # lexicographic tuple comparison is the ONLY stage-1 criterion
            k = res.meta["pmcd_key"]
            if k < best_pmcd_key:
                best_pmcd_key = k
                pmcd_min = [res]
            elif k == best_pmcd_key:
                pmcd_min.append(res)

        # tie-break among PMCD-min by heuristic cost (then mapped_pairs desc)
        for r in pmcd_min:
            r.meta.setdefault("explored_candidates", explored)
            r.meta.setdefault("pmcd_min_set_size", len(pmcd_min))

        pmcd_min.sort(
            key=lambda r: (
                float(r.meta.get("heuristic_cost", 1e9)),
                -int(r.meta.get("mapped_pairs", 0)),
                float(r.score),
            )
        )

        # keep only up to max_solutions (always at least one)
        pmcd_min = pmcd_min[: max(1, int(self.cfg.max_solutions))]
        best = pmcd_min[0]
        self._final_meta_its(best)
        return best, pmcd_min, full_pool

    def _time_exceeded(self, t0: float) -> bool:
        """True when ``cfg.time_limit_s`` (None/0 = unlimited) has elapsed.

        :param t0: start time taken from ``time.monotonic()``.
        """
        lim = self.cfg.time_limit_s
        if lim is None or lim == 0:
            return False
        return (time.monotonic() - t0) >= float(lim)

    # -------- candidate stream (exploration only; PMCD ignores candidate priors) --------

    def _candidate_stream(self) -> Iterator[EdgeMaskCandidate]:
        """Yield bond-cut candidates, or nothing when bond cutting is disabled."""
        if not self.cfg.enable_bond_cut or self.cfg.max_bond_cut_size <= 0:
            return iter(())  # type: ignore[return-value]
        return self._bond_cut_candidates()

    def _bond_cut_candidates(self) -> Iterator[EdgeMaskCandidate]:
        """Generate edge-removal masks around the baseline reaction centre."""
        if (
            self._r_cache is None
            or self._p_cache is None
            or self._rG is None
            or self._pG is None
        ):
            return

        # edges incident to the "likely RC" from a graph-only baseline diff
        # (cheap; no heuristic scoring involved)
        base_map = self._map_wl(None, None, None)
        rc_r, rc_p = self._diff_rc_nodes(base_map)

        edges_r = self._incident_edges_unique(self._r_cache, rc_r)
        edges_p = self._incident_edges_unique(self._p_cache, rc_p)

        # enrich with global priority edges so acyl/peroxy patterns are reachable
        edges_r = self._merge_unique(
            edges_r, self._global_priority_edges(self._r_cache, limit=48)
        )
        edges_p = self._merge_unique(
            edges_p, self._global_priority_edges(self._p_cache, limit=48)
        )

        edges_r = self._prioritize_edges(self._r_cache, edges_r)
        edges_p = self._prioritize_edges(self._p_cache, edges_p)

        if self.cfg.enable_symmetry_pruning and self._r_class and self._p_class:
            edges_r = self._sym.unique_edges(
                edges_r, self._r_class, lambda u, v: self._r_cache.edge_orders(u, v)
            )
            edges_p = self._sym.unique_edges(
                edges_p, self._p_class, lambda u, v: self._p_cache.edge_orders(u, v)
            )

        pool_r = edges_r[: int(self.cfg.candidate_edge_pool)]
        pool_p = edges_p[: int(self.cfg.candidate_edge_pool)]

        # Stage A: up to max_bond_cut_size
        yield from self._emit_cut_combos(
            "reactant", pool_r, max_k=int(self.cfg.max_bond_cut_size), tag="bond_cut"
        )
        yield from self._emit_cut_combos(
            "product", pool_p, max_k=int(self.cfg.max_bond_cut_size), tag="bond_cut"
        )

        # Stage B: optional deeper masks (still PMCD-first; used only to
        # discover additional PMCD-min solutions)
        if (
            self.cfg.enable_heuristic_scoring
            and self.cfg.heuristic_max_cut_size > self.cfg.max_bond_cut_size
        ):
            hk = int(self.cfg.heuristic_max_cut_size)
            budget = int(self.cfg.heuristic_candidate_budget)
            yield from self._emit_cut_combos(
                "reactant", pool_r, max_k=hk, tag="deep_cut", combo_cap=budget
            )
            yield from self._emit_cut_combos(
                "product", pool_p, max_k=hk, tag="deep_cut", combo_cap=budget
            )

    @staticmethod
    def _merge_unique(
        a: List[Tuple[Hashable, Hashable]], b: List[Tuple[Hashable, Hashable]]
    ) -> List[Tuple[Hashable, Hashable]]:
        """Append edges of ``b`` not already in ``a`` (order-insensitive pairs)."""
        seen = set(_norm_pair(x, y) for x, y in a)
        out = list(a)
        for x, y in b:
            p = _norm_pair(x, y)
            if p in seen:
                continue
            seen.add(p)
            out.append(p)
        return out

    def _emit_cut_combos(
        self,
        side: str,
        pool: List[Tuple[Hashable, Hashable]],
        *,
        max_k: int,
        tag: str,
        combo_cap: int = 200,
    ) -> Iterator[EdgeMaskCandidate]:
        """Yield masks removing k edges of ``pool`` (k=1..max_k, capped per k)."""
        max_k = max(1, int(max_k))
        combo_cap = max(1, int(combo_cap))
        for k in range(1, max_k + 1):
            for subset in itertools.islice(
                itertools.combinations(pool, k), 0, combo_cap
            ):
                removed = frozenset(_norm_pair(a, b) for a, b in subset)
                meta = {
                    "candidate": tag,
                    "branch": tag,
                    "side": side,
                    "k": k,
                    "cut_edges": list(subset),
                }
                yield EdgeMaskCandidate(
                    side=side, removed_pairs=removed, prior_score=0.0, meta=meta
                )

    # -------- edge ranking helpers (search ordering only) --------

    def _xo_nodes(
        self, cache: GraphCache, u: Hashable, v: Hashable
    ) -> Tuple[Optional[Hashable], Optional[Hashable]]:
        """Return (X, O) for an X-O bond with X in {C, Si}, else (None, None)."""
        eu = cache.element(u)
        ev = cache.element(v)
        if eu == "O" and ev in {"C", "Si"}:
            return v, u
        if ev == "O" and eu in {"C", "Si"}:
            return u, v
        return None, None

    def _is_peroxy_oo(self, cache: GraphCache, u: Hashable, v: Hashable) -> bool:
        """True for an O-O bond where either oxygen is a peroxyacyl oxygen."""
        if cache.element(u) != "O" or cache.element(v) != "O":
            return False
        return cache.is_acyl_oxygen_peroxy(u) or cache.is_acyl_oxygen_peroxy(v)

    def _edge_rank(self, cache: GraphCache, u: Hashable, v: Hashable) -> int:
        """Search-order rank of an edge (lower = tried earlier)."""
        # peroxy O-O: very important to include early
        if cache.element(u) == "O" and cache.element(v) == "O":
            if self._is_peroxy_oo(cache, u, v):
                return -5
            return 95
        x, o = self._xo_nodes(cache, u, v)
        if x is None or o is None:
            return 90
        if cache.element(x) == "C" and cache.is_carbonyl_c(x):
            if cache.edge_is_single(x, o):
                if cache.is_acyl_oxygen_peroxy(o):
                    return 25  # keep available, but not before peroxy O-O
                return 0
            return 999  # C=O
        if cache.element(x) == "C" and cache.aromatic(x):
            return 80
        deg = cache.degree(x)
        if deg >= 3:
            return 10
        if deg == 2:
            return 20
        return 30

    def _global_priority_edges(
        self, cache: GraphCache, limit: int = 48
    ) -> List[Tuple[Hashable, Hashable]]:
        """All X-O (X in C/Si) and peroxy O-O edges, ranked, truncated to ``limit``."""
        out: List[Tuple[Hashable, Hashable]] = []
        for u, v in cache.G.edges():
            eu = cache.element(u)
            ev = cache.element(v)
            keep = False
            if (eu == "O" and ev in {"C", "Si"}) or (ev == "O" and eu in {"C", "Si"}):
                keep = True
            elif eu == "O" and ev == "O" and self._is_peroxy_oo(cache, u, v):
                keep = True
            if not keep:
                continue
            out.append(_norm_pair(u, v))
        out = list(dict.fromkeys(out))
        out.sort(
            key=lambda e: (self._edge_rank(cache, e[0], e[1]), _stable_sort_key(e))
        )
        return out[: max(0, int(limit))]

    def _prioritize_edges(
        self, cache: GraphCache, edges: List[Tuple[Hashable, Hashable]]
    ) -> List[Tuple[Hashable, Hashable]]:
        """Sort edges by rank, then by total bond order (stronger first)."""

        def bond_strength_sum(orders: Tuple[Any, ...]) -> float:
            # numeric or numeric-string orders count; anything else adds 0
            s = 0.0
            for x in orders:
                if isinstance(x, (int, float)):
                    s += float(x)
                elif isinstance(x, str):
                    try:
                        s += float(x)
                    except Exception:
                        s += 0.0
            return s

        def key(e: Tuple[Hashable, Hashable]) -> Tuple[int, float, str]:
            u, v = e
            rank = self._edge_rank(cache, u, v)
            s = bond_strength_sum(cache.edge_orders(u, v))
            return (rank, -s, _stable_sort_key((u, v)))

        return sorted(edges, key=key)

    def _incident_edges_unique(
        self, cache: GraphCache, nodes: List[Hashable]
    ) -> List[Tuple[Hashable, Hashable]]:
        """Normalised, de-duplicated edges incident to any node in ``nodes``."""
        seen: set = set()
        out: List[Tuple[Hashable, Hashable]] = []
        for u in nodes:
            if u not in cache.node_set:
                continue
            for v in cache.neighbors(u):
                a, b = _norm_pair(u, v)
                if (a, b) in seen:
                    continue
                seen.add((a, b))
                out.append((a, b))
        return out

    # -------- mapping core --------

    def _map_two_pass(
        self, cand: Optional[EdgeMaskCandidate]
    ) -> Dict[Hashable, Hashable]:
        """WL mapping, optionally refined with RC distances and pair swaps."""
        base = self._map_wl(cand, None, None)
        if not base:
            return {}
        chosen = base
        if self.cfg.enable_rc_refine:
            rc_r, rc_p = self._diff_rc_nodes(base)
            dist_r = self._dist_to_set(self._rG, rc_r)
            dist_p = self._dist_to_set(self._pG, rc_p)
            ref = self._map_wl(cand, dist_r, dist_p)
            # only accept the refined mapping if it strictly improves PMCD
            if ref and self._pmcd_key(ref) < self._pmcd_key(base):
                chosen = ref
        if self.cfg.enable_swap_refine:
            chosen = self._swap_refine_mapping(chosen)
        return chosen

    def _candidate_masks(
        self, cand: Optional[EdgeMaskCandidate]
    ) -> Tuple[Optional[frozenset], Optional[frozenset]]:
        """Split a candidate into (reactant mask, product mask)."""
        if cand is None:
            return None, None
        if cand.side == "reactant":
            return cand.removed_pairs, None
        return None, cand.removed_pairs

    def _masked_or_base_wl(
        self,
        view: MaskView,
        removed: Optional[frozenset],
        base_wl: Dict[Hashable, List[bytes]],
        init_labels: Dict[Hashable, bytes],
    ) -> Dict[Hashable, List[bytes]]:
        """WL hashes for ``view``: reuse the baseline when no mask applies."""
        if removed is None:
            return base_wl
        if not self.cfg.enable_dynamic_wl:
            return self._hasher.full_hashes(view, init_labels, self._blake_bytes)
        hmap, _ = self._hasher.masked_hashes(
            view, init_labels, base_wl, removed, self._blake_bytes
        )
        return hmap

    def _depth_schedule(self, max_depth: int) -> List[int]:
        """Depths to match at: deepest first down to 1, or only the deepest."""
        if self.cfg.progressive_fallback:
            return list(range(max_depth, 0, -1))
        return [max_depth]

    def _wl_key(
        self, hmap: Dict[Hashable, List[bytes]], u: Hashable, depth: int
    ) -> Optional[bytes]:
        """Hash of ``u`` at 1-based ``depth``, or None if unavailable."""
        seq = hmap.get(u)
        if not seq:
            return None
        if depth <= 0 or depth > len(seq):
            return None
        return seq[depth - 1]

    def _build_buckets(
        self, nodes: Iterable[Hashable], hmap: Dict[Hashable, List[bytes]], depth: int
    ) -> Dict[bytes, List[Hashable]]:
        """Group ``nodes`` by their WL hash at ``depth``."""
        buckets: Dict[bytes, List[Hashable]] = {}
        for u in nodes:
            k = self._wl_key(hmap, u, depth)
            if k is None:
                continue
            buckets.setdefault(k, []).append(u)
        return buckets

    def _map_wl(
        self,
        cand: Optional[EdgeMaskCandidate],
        rc_dist_r: Optional[Dict[Hashable, int]],
        rc_dist_p: Optional[Dict[Hashable, int]],
    ) -> Dict[Hashable, Hashable]:
        """Match reactant/product nodes bucket-by-bucket over the WL depths."""
        if (
            self._r_cache is None
            or self._p_cache is None
            or self._rG is None
            or self._pG is None
        ):
            return {}

        rm, pm = self._candidate_masks(cand)
        rv = MaskView(self._r_cache, rm)
        pv = MaskView(self._p_cache, pm)

        r_hash = self._masked_or_base_wl(rv, rm, self._r_wl, self._r_init)
        p_hash = self._masked_or_base_wl(pv, pm, self._p_wl, self._p_init)
        if not r_hash or not p_hash:
            return {}

        # unmatched node pools, shrunk in place by _match_bucket
        r_un = set(self._r_cache.nodes)
        p_un = set(self._p_cache.nodes)
        mapping: Dict[Hashable, Hashable] = {}

        max_depth = len(next(iter(r_hash.values())))
        for depth in self._depth_schedule(max_depth):
            r_b = self._build_buckets(r_un, r_hash, depth)
            p_b = self._build_buckets(p_un, p_hash, depth)
            for key, ru in r_b.items():
                pv_nodes = p_b.get(key)
                if not pv_nodes:
                    continue
                self._match_bucket(
                    rv, pv, ru, pv_nodes, mapping, r_un, p_un, rc_dist_r, rc_dist_p
                )
        return mapping

    def _use_hungarian(self, n: int, m: int) -> bool:
        """Use exact assignment only for buckets within ``hungarian_max_size``."""
        lim = int(self.cfg.hungarian_max_size)
        return n <= lim and m <= lim

    def _match_bucket(
        self,
        rv: MaskView,
        pv: MaskView,
        ru: List[Hashable],
        pv_nodes: List[Hashable],
        mapping: Dict[Hashable, Hashable],
        r_un: Set[Hashable],
        p_un: Set[Hashable],
        rc_dist_r: Optional[Dict[Hashable, int]],
        rc_dist_p: Optional[Dict[Hashable, int]],
    ) -> None:
        """Assign nodes of one hash bucket; mutates ``mapping``, ``r_un``, ``p_un``."""
        if len(ru) == 1 and len(pv_nodes) == 1:
            u = ru[0]
            v = pv_nodes[0]
            if self._compatible(rv.cache, u, pv.cache, v):
                mapping[u] = v
                r_un.discard(u)
                p_un.discard(v)
            return

        ru_s = sorted(ru, key=_stable_sort_key)
        pv_s = sorted(pv_nodes, key=_stable_sort_key)

        if self._use_hungarian(len(ru_s), len(pv_s)):
            cost = self._cost_matrix(rv, pv, ru_s, pv_s, mapping, rc_dist_r, rc_dist_p)
            pairs = self._hungarian_min_cost_assignment(cost)
            self._accept_pairs_matrix(ru_s, pv_s, cost, pairs, r_un, p_un, mapping)
            return

        pairs = self._greedy_pairs_topk(
            rv, pv, ru_s, pv_s, mapping, rc_dist_r, rc_dist_p
        )
        for i, j in pairs:
            u = ru_s[i]
            v = pv_s[j]
            mapping[u] = v
            r_un.discard(u)
            p_un.discard(v)

    def _mapped_neighbors_per_u(
        self, rv: MaskView, ru_s: List[Hashable], mapping: Dict[Hashable, Hashable]
    ) -> Dict[Hashable, List[Hashable]]:
        """For each u, its (masked-view) neighbours that are already mapped."""
        out: Dict[Hashable, List[Hashable]] = {}
        for u in ru_s:
            xs = [x for x in rv.neighbors(u) if x in mapping]
            if xs:
                out[u] = xs
        return out

    def _rc_distance_penalty(
        self,
        u: Hashable,
        v: Hashable,
        rc_dist_r: Optional[Dict[Hashable, int]],
        rc_dist_p: Optional[Dict[Hashable, int]],
    ) -> float:
        """Weighted |dist_r(u) - dist_p(v)| to the RC (0 when no distances)."""
        if rc_dist_r is None or rc_dist_p is None:
            return 0.0
        du = float(rc_dist_r.get(u, 10_000))
        dv = float(rc_dist_p.get(v, 10_000))
        return float(self.cfg.rc_distance_weight) * abs(du - dv)

    @staticmethod
    def _compatible(
        r_cache: GraphCache, u: Hashable, p_cache: GraphCache, v: Hashable
    ) -> bool:
        """Nodes may be mapped only if they carry the same element."""
        return r_cache.element(u) == p_cache.element(v)

    def _local_distance(
        self, rv: MaskView, pv: MaskView, u: Hashable, v: Hashable
    ) -> float:
        """Local attribute dissimilarity of u and v (inf for different elements)."""
        rC = rv.cache
        pC = pv.cache
        if rC.element(u) != pC.element(v):
            return math.inf
        base = 0.0
        base += 2.0 if rC.aromatic(u) != pC.aromatic(v) else 0.0
        base += abs(rC.hcount(u) - pC.hcount(v)) * 1.0
        base += abs(rC.charge(u) - pC.charge(v)) * 2.0
        base += abs(rv.degree(u) - pv.degree(v)) * 0.5
        base += float(_multiset_l1_sorted(rv.neigh_codes(u), pv.neigh_codes(v)))
        return base

    def _neighbor_consistency_penalty(
        self,
        rv: MaskView,
        pv: MaskView,
        u: Hashable,
        v: Hashable,
        mapping: Dict[Hashable, Hashable],
        mapped_neigh_u: Optional[List[Hashable]],
        neigh_set_v: Set[Hashable],
    ) -> float:
        """Penalty for mapped neighbours of u whose images are not adjacent to v
        (1.0 each) or are adjacent with a different edge label (0.5 each)."""
        if not mapped_neigh_u:
            return 0.0
        miss = 0.0
        for x in mapped_neigh_u:
            y = mapping.get(x)
            if y is None:
                continue
            if y not in neigh_set_v:
                miss += 1.0
                continue
            if rv.edge_label_one(u, x) != pv.edge_label_one(v, y):
                miss += 0.5
        return miss

    def _pair_cost(
        self,
        rv: MaskView,
        pv: MaskView,
        u: Hashable,
        v: Hashable,
        mapping: Dict[Hashable, Hashable],
        mapped_neigh_u: Optional[List[Hashable]],
        neigh_set_v: Set[Hashable],
        rc_dist_r: Optional[Dict[Hashable, int]],
        rc_dist_p: Optional[Dict[Hashable, int]],
    ) -> float:
        """Total cost of mapping u -> v (inf when incompatible)."""
        if not self._compatible(rv.cache, u, pv.cache, v):
            return math.inf
        base = self._local_distance(rv, pv, u, v)
        base += self._neighbor_consistency_penalty(
            rv, pv, u, v, mapping, mapped_neigh_u, neigh_set_v
        )
        base += self._rc_distance_penalty(u, v, rc_dist_r, rc_dist_p)
        return base

    def _cost_matrix(
        self,
        rv: MaskView,
        pv: MaskView,
        ru_s: List[Hashable],
        pv_s: List[Hashable],
        mapping: Dict[Hashable, Hashable],
        rc_dist_r: Optional[Dict[Hashable, int]],
        rc_dist_p: Optional[Dict[Hashable, int]],
    ) -> List[List[float]]:
        """Dense ``len(ru_s) x len(pv_s)`` pair-cost matrix."""
        mapped_neigh = self._mapped_neighbors_per_u(rv, ru_s, mapping)
        pv_neigh = {v: pv.neighbors_set(v) for v in pv_s}
        out: List[List[float]] = []
        for u in ru_s:
            row: List[float] = []
            mu = mapped_neigh.get(u)
            for v in pv_s:
                row.append(
                    self._pair_cost(
                        rv, pv, u, v, mapping, mu, pv_neigh[v], rc_dist_r, rc_dist_p
                    )
                )
            out.append(row)
        return out

    def _greedy_pairs_topk(
        self,
        rv: MaskView,
        pv: MaskView,
        ru_s: List[Hashable],
        pv_s: List[Hashable],
        mapping: Dict[Hashable, Hashable],
        rc_dist_r: Optional[Dict[Hashable, int]],
        rc_dist_p: Optional[Dict[Hashable, int]],
    ) -> List[Tuple[int, int]]:
        """Greedy one-to-one assignment using each row's top-k cheapest columns."""
        mapped_neigh = self._mapped_neighbors_per_u(rv, ru_s, mapping)
        pv_neigh = {v: pv.neighbors_set(v) for v in pv_s}
        topk = int(self.cfg.greedy_topk_per_u)

        items: List[Tuple[float, int, int]] = []
        for i, u in enumerate(ru_s):
            mu = mapped_neigh.get(u)
            best: List[Tuple[float, int]] = []
            for j, v in enumerate(pv_s):
                c = self._pair_cost(
                    rv, pv, u, v, mapping, mu, pv_neigh[v], rc_dist_r, rc_dist_p
                )
                if math.isinf(c) or c > 1e8:
                    continue
                best.append((float(c), j))
            if not best:
                continue
            best.sort(key=lambda x: x[0])
            for c, j in best[:topk]:
                items.append((c, i, j))

        items.sort(key=lambda x: x[0])
        used_i: set = set()
        used_j: set = set()
        out: List[Tuple[int, int]] = []
        for _, i, j in items:
            if i in used_i or j in used_j:
                continue
            used_i.add(i)
            used_j.add(j)
            out.append((i, j))
        return out

    @staticmethod
    def _hungarian_min_cost_assignment(
        cost: List[List[float]],
    ) -> List[Tuple[int, int]]:
        """
        Minimum-cost assignment (Hungarian / Kuhn-Munkres, O(N^3)).

        The matrix is padded to ``N x N`` with ``BIG``; infinite entries are
        replaced by ``BIG`` as well. Only pairs inside the original
        ``n x m`` matrix are returned (callers filter infeasible ones by cost).

        :param cost: ``n x m`` cost matrix (may be rectangular or empty).
        :returns: list of (row, col) pairs, each row/col used at most once.
        """
        n = len(cost)
        m = len(cost[0]) if n else 0
        if n == 0 or m == 0:
            return []
        N = max(n, m)
        BIG = 1e9
        a = [[BIG] * N for _ in range(N)]
        for i in range(n):
            for j in range(m):
                c = cost[i][j]
                a[i][j] = BIG if math.isinf(c) else float(c)

        # The search sentinel must be strictly larger than every matrix entry:
        # with a sentinel equal to BIG, a reduced cost of exactly BIG never
        # refreshes way[j], leaving a stale augmenting path from an earlier
        # phase that can assign one row twice.
        INF = math.inf
        u = [0.0] * (N + 1)
        v = [0.0] * (N + 1)
        p = [0] * (N + 1)
        way = [0] * (N + 1)
        for i in range(1, N + 1):
            p[0] = i
            j0 = 0
            minv = [INF] * (N + 1)
            used = [False] * (N + 1)
            while True:
                used[j0] = True
                i0 = p[j0]
                delta = INF
                j1 = 0
                for j in range(1, N + 1):
                    if used[j]:
                        continue
                    cur = a[i0 - 1][j - 1] - u[i0] - v[j]
                    if cur < minv[j]:
                        minv[j] = cur
                        way[j] = j0
                    if minv[j] < delta:
                        delta = minv[j]
                        j1 = j
                for j in range(0, N + 1):
                    if used[j]:
                        u[p[j]] += delta
                        v[j] -= delta
                    else:
                        minv[j] -= delta
                j0 = j1
                if p[j0] == 0:
                    break
            # augment along the recorded path
            while True:
                j1 = way[j0]
                p[j0] = p[j1]
                j0 = j1
                if j0 == 0:
                    break

        ans: List[Tuple[int, int]] = []
        for j in range(1, N + 1):
            i = p[j]
            if 1 <= i <= n and 1 <= j <= m:
                ans.append((i - 1, j - 1))
        return ans

    @staticmethod
    def _accept_pairs_matrix(
        ru_s: List[Hashable],
        pv_s: List[Hashable],
        cost: List[List[float]],
        pairs: List[Tuple[int, int]],
        r_un: Set[Hashable],
        p_un: Set[Hashable],
        mapping: Dict[Hashable, Hashable],
    ) -> None:
        """Commit assignment pairs whose cost is finite and below 1e8."""
        for i, j in pairs:
            c = cost[i][j]
            if math.isinf(c) or c > 1e8:
                continue
            u = ru_s[i]
            v = pv_s[j]
            mapping[u] = v
            r_un.discard(u)
            p_un.discard(v)

    # -------- swap refine --------

    def _swap_refine_mapping(
        self, mapping: Dict[Hashable, Hashable]
    ) -> Dict[Hashable, Hashable]:
        """Hill-climb by swapping images within WL-equivalent groups."""
        if not mapping:
            return mapping
        groups = self._swap_groups(mapping)
        if not groups:
            return mapping

        cur = dict(mapping)
        best = self._heuristic_bond_cost(cur)  # heuristic tie-break, not PMCD
        for _ in range(int(self.cfg.swap_refine_max_iter)):
            improved = False
            for nodes in groups:
                cur2, best2, changed = self._swap_refine_group(cur, best, nodes)
                if changed:
                    cur, best = cur2, best2
                    improved = True
            if not improved:
                break
        return cur

    def _swap_groups(self, mapping: Dict[Hashable, Hashable]) -> List[List[Hashable]]:
        """Groups of mapped reactant nodes sharing (reactant, product) WL prefixes."""
        depth = int(self.cfg.swap_refine_class_depth)
        max_sz = int(self.cfg.swap_refine_max_group_size)
        buckets: Dict[Tuple[Tuple[bytes, ...], Tuple[bytes, ...]], List[Hashable]] = {}
        for u, v in mapping.items():
            rk = self._wl_prefix(self._r_wl.get(u, []), depth)
            pk = self._wl_prefix(self._p_wl.get(v, []), depth)
            if rk is None or pk is None:
                continue
            buckets.setdefault((rk, pk), []).append(u)
        out: List[List[Hashable]] = []
        for nodes in buckets.values():
            if 2 <= len(nodes) <= max_sz:
                out.append(sorted(nodes, key=_stable_sort_key))
        return out

    @staticmethod
    def _wl_prefix(seq: List[bytes], depth: int) -> Optional[Tuple[bytes, ...]]:
        """First ``min(depth, len(seq))`` hashes, or None if empty/non-positive."""
        if not seq:
            return None
        d = min(int(depth), len(seq))
        if d <= 0:
            return None
        return tuple(seq[:d])

    def _swap_refine_group(
        self,
        mapping: Dict[Hashable, Hashable],
        best_score: float,
        nodes: List[Hashable],
    ) -> Tuple[Dict[Hashable, Hashable], float, bool]:
        """Return the first pairwise swap in ``nodes`` that lowers the cost."""
        cur = dict(mapping)
        best = float(best_score)
        for i in range(len(nodes)):
            for j in range(i + 1, len(nodes)):
                u1 = nodes[i]
                u2 = nodes[j]
                if u1 not in cur or u2 not in cur:
                    continue
                v1 = cur[u1]
                v2 = cur[u2]
                if v1 == v2:
                    continue
                trial = dict(cur)
                trial[u1] = v2
                trial[u2] = v1
                sc = self._heuristic_bond_cost(trial)
                if sc < best:
                    return trial, sc, True
        return cur, best, False

    # -------- PMCD key + scoring --------

    def _pmcd_key(self, mapping: Dict[Hashable, Hashable]) -> Tuple[int, int, int]:
        """
        PMCD key: (unmapped_atoms, bond_changes_count, hcount_changes_count)
        This is the ONLY criterion used in stage-1 selection.
        """
        if self._r_cache is None or self._p_cache is None:
            return (10**9, 10**9, 10**9)

        ur, up = self._unmapped_counts(mapping)
        if not mapping:
            return (ur + up, 10**9, 10**9)

        # Reactant-side RC (expanded by rc_expand_hops) is only needed when
        # one of the rc_only_* restrictions is active; the product side was
        # never used, so it is not computed.
        rc_r: List[Hashable] = []
        if self.cfg.rc_only_bond_changes or self.cfg.rc_only_hcount_changes:
            rc_r, _ = self._diff_rc_nodes(mapping)
            rc_r = self._expand_nodes(self._r_cache, rc_r, int(self.cfg.rc_expand_hops))

        if self.cfg.rc_only_bond_changes:
            sub = {u: mapping[u] for u in rc_r if u in mapping}
            bond_changes = self._bond_change_count(sub)
        else:
            bond_changes = self._bond_change_count(mapping)

        if self.cfg.rc_only_hcount_changes:
            hcount_changes = self._hcount_change_count(mapping, rc_r)
        else:
            hcount_changes = self._hcount_change_count(mapping, None)

        return (int(ur + up), int(bond_changes), int(hcount_changes))

    def _pmcd_numeric(self, key: Tuple[int, int, int]) -> float:
        """Numeric proxy of a PMCD key; comparisons MUST use the tuple key."""
        # large bases avoid collisions for typical sizes
        u, b, h = key
        return float(u) * 1e6 + float(b) * 1e3 + float(h)

    def _score_pmcd_and_heuristic(
        self, mapping: Dict[Hashable, Hashable], meta: Dict[str, Any]
    ) -> MappingResult:
        """Wrap ``mapping`` in a MappingResult with PMCD and heuristic meta."""
        pmcd_key = self._pmcd_key(mapping)
        pmcd_score = self._pmcd_numeric(pmcd_key)
        heur_cost = self._heuristic_cost(mapping, meta)
        out = MappingResult(mapping=mapping, score=float(pmcd_score), meta=dict(meta))
        out.meta["pmcd_key"] = tuple(pmcd_key)
        out.meta["pmcd_score"] = float(pmcd_score)
        out.meta["heuristic_cost"] = float(heur_cost)
        out.meta["mapped_pairs"] = int(len(mapping))
        out.meta["unmapped_atoms"] = int(pmcd_key[0])
        out.meta["bond_changes"] = int(pmcd_key[1])
        out.meta["hcount_changes"] = int(pmcd_key[2])
        return out

    # -------- PMCD components --------

    def _bond_change_count(self, mapping: Dict[Hashable, Hashable]) -> int:
        """
        Unweighted chemical distance: count of bond edits
        (removed/created/order mismatch) = 1 each.
        """
        if (
            self._r_cache is None
            or self._p_cache is None
            or self._rG is None
            or self._pG is None
        ):
            return 10**9

        inv = {pv: ru for ru, pv in mapping.items()}
        mapped_r = set(mapping.keys())
        mapped_p = set(mapping.values())
        seen_p: set = set()
        changes = 0

        # edges in reactants among mapped nodes -> broken or order-changed
        for u, v in self._iter_edges_between(self._rG, mapped_r):
            pu = mapping.get(u)
            pv = mapping.get(v)
            if pu is None or pv is None:
                continue
            a, b = _norm_pair(pu, pv)
            if (a, b) in seen_p:
                continue
            seen_p.add((a, b))
            ro = self._r_cache.edge_orders(u, v)
            po = self._p_cache.edge_orders(pu, pv)
            if not po:
                changes += 1
            elif ro != po:
                changes += 1

        # edges created in product among mapped nodes (absent in reactant)
        for u, v in self._iter_edges_between(self._pG, mapped_p):
            ru = inv.get(u)
            rv = inv.get(v)
            if ru is None or rv is None:
                continue
            if not self._r_cache.edge_orders(ru, rv):
                changes += 1

        return int(changes)

    def _hcount_change_count(
        self, mapping: Dict[Hashable, Hashable], nodes: Optional[List[Hashable]]
    ) -> int:
        """Sum of |H(u) - H(mapping[u])| over ``nodes`` (all mapped if None)."""
        if not mapping or self._r_cache is None or self._p_cache is None:
            return 0
        it = mapping.keys() if nodes is None else (u for u in nodes if u in mapping)
        total = 0
        for u in it:
            v = mapping.get(u)
            if v is None:
                continue
            total += abs(self._r_cache.hcount(u) - self._p_cache.hcount(v))
        return int(total)

    # -------- heuristic tie-break (ONLY stage 2) --------

    def _heuristic_cost(
        self, mapping: Dict[Hashable, Hashable], meta: Dict[str, Any]
    ) -> float:
        """
        Tie-break ONLY:
          - type-weighted bond change cost (handles peracid: prefers O-O cleavage)
          - + small candidate cut preference penalty/bonus (based on cut_edges)
        """
        if not mapping:
            return 1e9

        bond_cost = self._heuristic_bond_cost(mapping)

        cut_pen = 0.0
        side = meta.get("side")
        cut_edges = meta.get("cut_edges") or []
        if (
            isinstance(cut_edges, list)
            and cut_edges
            and (side in {"reactant", "product"})
        ):
            cache = self._r_cache if side == "reactant" else self._p_cache
            if cache is not None:
                cut_pen = self._heuristic_cut_penalty(cache, cut_edges)

        # prefer fewer cuts slightly (but only after PMCD)
        k = float(meta.get("k", 0.0) or 0.0)
        return float(bond_cost) + float(cut_pen) + 0.02 * k

    def _heuristic_cut_penalty(
        self, cache: GraphCache, cut_edges: List[Tuple[Hashable, Hashable]]
    ) -> float:
        """
        If peroxyacyl present, strongly prefer cutting O-O (outer oxygen bond),
        and strongly avoid cutting acyl C-O in C(=O)-O-O.
        """
        pen = 0.0
        for u, v in cut_edges:
            eu = cache.element(u)
            ev = cache.element(v)

            # bonus: cutting peroxy O-O
            if eu == "O" and ev == "O" and self._is_peroxy_oo(cache, u, v):
                pen -= 1.5
                continue

            x, o = self._xo_nodes(cache, u, v)
            if x is None or o is None:
                continue

            # penalty: cutting acyl C-O when that oxygen is peroxyacyl oxygen
            if (
                cache.element(x) == "C"
                and cache.is_carbonyl_c(x)
                and cache.is_acyl_oxygen_peroxy(o)
            ):
                pen += 2.0
        return float(pen)

    def _heuristic_bond_cost(self, mapping: Dict[Hashable, Hashable]) -> float:
        """Type-weighted bond-change cost without building a report."""
        cost, _ = self._bond_change_cost_and_report(mapping, want_report=False)
        return float(cost)

    def _bond_change_cost_and_report(
        self, mapping: Dict[Hashable, Hashable], want_report: bool
    ) -> Tuple[float, List[Dict[str, Any]]]:
        """
        Type-weighted cost used ONLY in heuristic tie-break.

        :returns: (total cost, per-change records if ``want_report`` else []).
        """
        if (
            self._r_cache is None
            or self._p_cache is None
            or self._rG is None
            or self._pG is None
        ):
            return 0.0, []

        inv = {pv: ru for ru, pv in mapping.items()}
        mapped_r = set(mapping.keys())
        mapped_p = set(mapping.values())
        seen_p: set = set()

        total_cost = 0.0
        report: List[Dict[str, Any]] = []

        for u, v in self._iter_edges_between(self._rG, mapped_r):
            pu = mapping.get(u)
            pv = mapping.get(v)
            if pu is None or pv is None:
                continue
            a, b = _norm_pair(pu, pv)
            if (a, b) in seen_p:
                continue
            seen_p.add((a, b))

            ro = self._r_cache.edge_orders(u, v)
            po = self._p_cache.edge_orders(pu, pv)

            if not po:
                c = self._bond_type_cost(self._r_cache, u, v)
                total_cost += c
                if want_report:
                    report.append(
                        {
                            "kind": "removed_in_product",
                            "r_edge": (u, v),
                            "p_edge": (pu, pv),
                            "r_order": ro,
                            "p_order": po,
                            "cost": float(c),
                        }
                    )
                continue

            if ro != po:
                cr = self._bond_type_cost(self._r_cache, u, v)
                cp = self._bond_type_cost(self._p_cache, pu, pv)
                c = float(self.cfg.bc_cost_order_mismatch_scale) * 0.5 * (cr + cp)
                total_cost += c
                if want_report:
                    report.append(
                        {
                            "kind": "order_mismatch",
                            "r_edge": (u, v),
                            "p_edge": (pu, pv),
                            "r_order": ro,
                            "p_order": po,
                            "cost": float(c),
                        }
                    )

        for u, v in self._iter_edges_between(self._pG, mapped_p):
            ru = inv.get(u)
            rv = inv.get(v)
            if ru is None or rv is None:
                continue
            ro = self._r_cache.edge_orders(ru, rv)
            if not ro:
                c = self._bond_type_cost(self._p_cache, u, v)
                total_cost += c
                if want_report:
                    report.append(
                        {
                            "kind": "created_in_product",
                            "r_edge": (ru, rv),
                            "p_edge": (u, v),
                            "r_order": ro,
                            "p_order": self._p_cache.edge_orders(u, v),
                            "cost": float(c),
                        }
                    )

        return float(total_cost), report

    def _bond_type_cost(self, cache: GraphCache, u: Hashable, v: Hashable) -> float:
        """Configured cost of changing bond (u, v), chosen by bond type."""
        # peroxy O-O is cheap to change (preferred cleavage for peracids)
        if cache.element(u) == "O" and cache.element(v) == "O":
            if self._is_peroxy_oo(cache, u, v):
                return float(self.cfg.bc_cost_peroxy_oo)
            return float(self.cfg.bc_cost_other)

        x, o = self._xo_nodes(cache, u, v)
        if x is not None and o is not None:
            if cache.element(x) == "C" and cache.is_carbonyl_c(x):
                if cache.edge_is_double(x, o):
                    return float(self.cfg.heuristic_carbonyl_double_penalty)
                if cache.is_acyl_oxygen_peroxy(o):
                    return float(self.cfg.bc_cost_acyl_co_peroxy)
                return float(self.cfg.bc_cost_acyl_co)
            if cache.element(x) == "C" and cache.aromatic(x):
                return float(self.cfg.bc_cost_aromatic_co)
            deg = cache.degree(x)
            if deg >= 3:
                return float(self.cfg.bc_cost_x_deg3_o)
            if deg == 2:
                return float(self.cfg.bc_cost_x_deg2_o)
            return float(self.cfg.bc_cost_x_deg1_o)

        if (cache.element(u) == "C" and cache.aromatic(u)) or (
            cache.element(v) == "C" and cache.aromatic(v)
        ):
            return float(max(self.cfg.bc_cost_other, self.cfg.bc_cost_aromatic_co))
        return float(self.cfg.bc_cost_other)

    # -------- RC detection + utilities --------

    @staticmethod
    def _iter_edges_between(
        G: Any, nodes: Set[Hashable]
    ) -> Iterable[Tuple[Hashable, Hashable]]:
        """Yield edges of ``G`` whose both endpoints lie in ``nodes``."""
        if G is None:
            return
        for u, v in G.edges(nodes):
            if u in nodes and v in nodes:
                yield u, v

    def _unmapped_counts(self, mapping: Dict[Hashable, Hashable]) -> Tuple[int, int]:
        """(# unmapped reactant nodes, # unmapped product nodes)."""
        if self._r_cache is None or self._p_cache is None:
            return 0, 0
        mr = set(mapping.keys())
        mp = set(mapping.values())
        ur = sum(1 for u in self._r_cache.nodes if u not in mr)
        up = sum(1 for v in self._p_cache.nodes if v not in mp)
        return int(ur), int(up)

    def _diff_rc_nodes(
        self, mapping: Dict[Hashable, Hashable]
    ) -> Tuple[List[Hashable], List[Hashable]]:
        """Reactant RC nodes (sorted by repr) and their product images."""
        if not mapping:
            return [], []
        changed_r = self._changed_endpoints_r(mapping)
        rc_r = sorted(changed_r, key=repr)
        rc_p = [mapping[u] for u in rc_r if u in mapping]
        return rc_r, rc_p

    def _changed_endpoints_r(self, mapping: Dict[Hashable, Hashable]) -> Set[Hashable]:
        """Reactant endpoints of every broken, created or order-changed bond."""
        if (
            self._r_cache is None
            or self._p_cache is None
            or self._rG is None
            or self._pG is None
        ):
            return set()

        inv = {pv: ru for ru, pv in mapping.items()}
        mapped_r = set(mapping.keys())
        mapped_p = set(mapping.values())
        out: Set[Hashable] = set()

        for u, v in self._iter_edges_between(self._rG, mapped_r):
            pu = mapping.get(u)
            pv = mapping.get(v)
            if pu is None or pv is None:
                continue
            po = self._p_cache.edge_orders(pu, pv)
            ro = self._r_cache.edge_orders(u, v)
            if not po or ro != po:
                out.add(u)
                out.add(v)

        for u, v in self._iter_edges_between(self._pG, mapped_p):
            ru = inv.get(u)
            rv = inv.get(v)
            if ru is None or rv is None:
                continue
            if not self._r_cache.edge_orders(ru, rv):
                out.add(ru)
                out.add(rv)

        return out

    def _expand_nodes(
        self, cache: GraphCache, nodes: List[Hashable], hops: int
    ) -> List[Hashable]:
        """``nodes`` plus all nodes within ``hops`` edges (sorted by repr)."""
        if hops <= 0 or not nodes:
            return list(dict.fromkeys(nodes))
        seen = set(nodes)
        frontier = set(nodes)
        for _ in range(hops):
            nxt: Set[Hashable] = set()
            for u in frontier:
                for v in cache.neighbors(u):
                    if v not in seen:
                        seen.add(v)
                        nxt.add(v)
            frontier = nxt
            if not frontier:
                break
        return sorted(seen, key=repr)

    def _dist_to_set(self, G: Any, sources: List[Hashable]) -> Dict[Hashable, int]:
        """
        Unweighted hop distance from each reachable node to its nearest source.

        Uses one multi-source BFS over the undirected view of ``G``; sources
        not present in the graph are ignored. Unreachable nodes are absent.

        :returns: ``{node: distance}``; ``{}`` if ``G`` is None, no sources are
            given, or none of them is in ``G``.
        """
        if not sources or G is None:
            return {}
        from collections import deque

        UG = G.to_undirected() if hasattr(G, "to_undirected") else G
        dist: Dict[Hashable, int] = {}
        queue: deque = deque()
        for s in sources:
            if s in UG and s not in dist:
                dist[s] = 0
                queue.append(s)
        # BFS layers guarantee the first visit is the shortest distance
        while queue:
            node = queue.popleft()
            nd = dist[node] + 1
            for nb in UG.neighbors(node):
                if nb not in dist:
                    dist[nb] = nd
                    queue.append(nb)
        return dist

    # -------- ITS meta (optional) --------

    def _final_meta_its(self, best: MappingResult) -> None:
        """Best-effort: record ITS reaction-centre sizes in ``best.meta``."""
        if not self.cfg.use_its_final:
            return
        if ITSConstruction is None or get_rc is None:
            return
        if not best.mapping or self._rG is None or self._pG is None:
            return
        try:
            rr, pp = self._rG.copy(), self._pG.copy()
            self._write_temp_atom_maps(rr, pp, best.mapping)
            its = ITSConstruction().fit(rr, pp).its  # type: ignore[attr-defined]
            rc = get_rc(its)  # type: ignore[misc]
            best.meta["its_rc_nodes"] = int(len(list(rc.nodes())))
            best.meta["its_rc_edges"] = int(len(list(rc.edges())))
        except Exception:
            # diagnostics only: never let ITS failures break fit(), but leave
            # a trace instead of swallowing the error silently
            self._log.debug("ITS reaction-centre metadata failed", exc_info=True)
            best.meta["its_rc_nodes"] = None
            best.meta["its_rc_edges"] = None

    # -------- materialization --------

    def _materialize_solutions(self, pool: List[MappingResult]) -> List[Solution]:
        """Mapped SMILES per result, dropping duplicate SMILES (order kept)."""
        out: List[Solution] = []
        seen: set = set()
        for res in pool:
            rsmi = self._materialize_one(res)
            if rsmi in seen:
                continue
            seen.add(rsmi)
            out.append(Solution(result=res, mapped_rsmi=rsmi))
        return out

    def _materialize_one(self, res: MappingResult) -> str:
        """Canonical atom-mapped reaction SMILES for one result."""
        if self._rG is None or self._pG is None:
            raise RuntimeError("missing graphs")
        rr, pp = self._rG.copy(), self._pG.copy()
        self._apply_atom_map_numbers(rr, pp, res.mapping)
        mapped = self._graphs_to_rsmi(rr, pp)
        return self._canon.canonicalise(mapped)

    def _graphs_to_rsmi(self, rG: Any, pG: Any) -> str:
        """Join reactant and product SMILES as ``r>>p``."""
        r_smi = self._graph_to_smi_strict(rG)
        p_smi = self._graph_to_smi_strict(pG)
        return f"{r_smi}>>{p_smi}"

    def _graph_to_smi_strict(self, G: Any) -> str:
        """Try several ``graph_to_smi`` signatures; first non-empty str wins."""
        for kwargs in (
            {"canonical": False, "use_atom_map": True},
            {"canonical": False, "atom_map_key": "atom_map"},
            {"use_atom_map": True},
            {"atom_map_key": "atom_map"},
            {"canonical": False},
            {},
        ):
            try:
                s = graph_to_smi(G, **kwargs)
            except Exception:
                s = None
            if isinstance(s, str) and s.strip():
                return s
        raise ValueError("graph_to_smi failed")

    def _apply_atom_map_numbers(
        self, rG: Any, pG: Any, mapping: Dict[Hashable, Hashable]
    ) -> None:
        """Write consecutive ``atom_map`` numbers for mapped pairs (in place)."""
        for u in rG.nodes:
            rG.nodes[u]["atom_map"] = int(self.cfg.unmapped_value)
        for v in pG.nodes:
            pG.nodes[v]["atom_map"] = int(self.cfg.unmapped_value)

        k = int(self.cfg.start_atom_map)
        used_r: Set[Hashable] = set()
        used_p: Set[Hashable] = set()

        for u in sorted(mapping.keys(), key=_stable_sort_key):
            v = mapping[u]
            rG.nodes[u]["atom_map"] = k
            pG.nodes[v]["atom_map"] = k
            used_r.add(u)
            used_p.add(v)
            k += 1

        if not self.cfg.assign_maps_to_unmapped:
            return

        for u in sorted((x for x in rG.nodes if x not in used_r), key=_stable_sort_key):
            rG.nodes[u]["atom_map"] = k
            k += 1
        for v in sorted((x for x in pG.nodes if x not in used_p), key=_stable_sort_key):
            pG.nodes[v]["atom_map"] = k
            k += 1

    @staticmethod
    def _write_temp_atom_maps(
        rr: Any, pp: Any, mapping: Dict[Hashable, Hashable]
    ) -> None:
        """Temporary 1-based ``atom_map`` numbering for ITS construction."""
        for u in rr.nodes:
            rr.nodes[u]["atom_map"] = 0
        for v in pp.nodes:
            pp.nodes[v]["atom_map"] = 0
        k = 1
        for u in sorted(mapping.keys(), key=repr):
            v = mapping[u]
            rr.nodes[u]["atom_map"] = k
            pp.nodes[v]["atom_map"] = k
            k += 1

    def _bond_change_report_for_mapping(
        self, mapping: Dict[Hashable, Hashable]
    ) -> List[Dict[str, Any]]:
        """Per-bond change records for ``mapping``."""
        _, rep = self._bond_change_cost_and_report(mapping, want_report=True)
        return rep

    def _blake_bytes(self, b: bytes) -> bytes:
        """BLAKE2b digest of ``b`` with ``cfg.digest_size`` bytes."""
        h = hashlib.blake2b(digest_size=int(self.cfg.digest_size))
        h.update(b)
        return h.digest()