from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple
import logging
from concurrent.futures import ProcessPoolExecutor
from itertools import combinations
import networkx as nx
from synkit.Synthesis.Reactor.syn_reactor import SynReactor
logger = logging.getLogger(__name__)
try:
from rdkit import Chem
except ImportError:
Chem = None
def _apply_rule_worker(
    args: Tuple[int, Any, str, bool, bool, Optional[str], Tuple[str, ...]],
) -> Tuple[int, Tuple[str, ...], List[str]]:
    """
    Apply one rule to one substrate and collect the reactor's SMILES output.

    The single packed-tuple argument lets the function be handed directly to
    an executor ``map`` call.

    :param args: ``(rule index, rule, substrate SMILES, explicit_h,
        implicit_temp, strategy, reactant keys)``.
    :returns: ``(rule index, reactant keys, list(reactor.smiles_list))``; the
        index and keys are echoed back unchanged.
    """
    (
        rule_index,
        template,
        substrate_smiles,
        explicit_h,
        implicit_temp,
        strategy,
        mix_keys,
    ) = args

    reactor_kwargs: Dict[str, Any] = {
        "smiles": substrate_smiles,
        "template": template,
        "invert": False,
        "explicit_h": explicit_h,
        "implicit_temp": implicit_temp,
    }
    # ``strategy`` is only forwarded when the caller chose one explicitly.
    if strategy is not None:
        reactor_kwargs["strategy"] = strategy

    products = list(SynReactor.from_smiles(**reactor_kwargs).smiles_list)
    return rule_index, mix_keys, products
def _count_lhs_components(text: str) -> Optional[int]:
if not text:
return None
lhs = text.split(">>", 1)[0].strip() if ">>" in text else text.strip()
parts = [p for p in lhs.split(".") if p.strip()]
return len(parts) if parts else None
def _canonicalize_nomap_rdkit(smiles: str) -> Optional[str]:
    """
    Return a canonical SMILES with atom maps zeroed and hydrogens removed.

    :param smiles: Input SMILES (may carry atom-map numbers).
    :returns: ``None`` for empty input, unparsable input or a sanitization
        failure; the input unchanged when RDKit (``Chem``) is unavailable;
        otherwise the canonical SMILES produced by RDKit.
    """
    if not smiles:
        return None
    if Chem is None:
        # No RDKit: best-effort passthrough of the raw string.
        return smiles

    parsed = Chem.MolFromSmiles(smiles, sanitize=False)
    if parsed is None:
        return None
    try:
        Chem.SanitizeMol(parsed)
    except Exception:
        return None

    for atom in parsed.GetAtoms():
        atom.SetAtomMapNum(0)
    stripped = Chem.RemoveAllHs(parsed)
    return Chem.MolToSmiles(stripped, canonical=True)
def _mol_from_smiles_safe(smiles: str) -> Optional[Any]:
    """
    Parse *smiles* with RDKit (no sanitization) without ever raising.

    :param smiles: SMILES text to parse.
    :returns: Whatever ``Chem.MolFromSmiles(smiles, sanitize=False)`` returns,
        or ``None`` when RDKit is unavailable, *smiles* is empty, or the
        parser raised.
    """
    if Chem is None:
        return None
    if not smiles:
        return None
    try:
        parsed = Chem.MolFromSmiles(smiles, sanitize=False)
    except Exception:
        parsed = None
    return parsed
def _sanitize_safe(mol: Any) -> bool:
"""Sanitize mol in-place; return True on success."""
try:
Chem.SanitizeMol(mol)
return True
except Exception:
return False
def _has_atom_maps(mol: Any) -> bool:
"""Return True if any atom has a non-zero atom-map number."""
try:
return any(a.GetAtomMapNum() > 0 for a in mol.GetAtoms())
except Exception:
return False
def _strip_maps_and_canonical_from_mol(mol: Any) -> Optional[str]:
"""
Zero all atom-map numbers, remove Hs and return canonical SMILES.
Returns None on failure.
"""
try:
for a in mol.GetAtoms():
a.SetAtomMapNum(0)
mol_nomap = Chem.RemoveAllHs(mol)
return Chem.MolToSmiles(mol_nomap, canonical=True)
except Exception:
return None
def _canonical_from_mol(mol: Any) -> Optional[str]:
"""Return canonical SMILES for mol (may keep maps)."""
try:
return Chem.MolToSmiles(mol, canonical=True)
except Exception:
return None
def _assign_deterministic_maps_and_canonical(nomap_smiles: str) -> Optional[str]:
"""
Given a canonical nomap SMILES:
- reparse it,
- sanitize,
- assign atom-map = atom index + 1,
- return canonical SMILES for the mapped mol.
"""
try:
mol2 = Chem.MolFromSmiles(nomap_smiles, sanitize=False)
if mol2 is None:
return None
if not _sanitize_safe(mol2):
return None
for a in mol2.GetAtoms():
a.SetAtomMapNum(a.GetIdx() + 1)
return Chem.MolToSmiles(mol2, canonical=True)
except Exception:
return None
def _standardize_smiles_rdkit(smiles: str, *, keep_aam: bool) -> Optional[str]:
"""
Thin wrapper that sequences small helpers to keep cyclomatic complexity low.
Behavior:
- If Chem is None: return input SMILES (passthrough) or None for empty input.
- If keep_aam=False: return canonical SMILES with maps stripped.
- If keep_aam=True:
* if input has maps -> return canonical SMILES (keep maps)
* else deterministically assign maps and return canonical mapped SMILES
"""
if not smiles:
return None
if Chem is None:
# RDKit not available: passthrough (best-effort)
return smiles
mol = _mol_from_smiles_safe(smiles)
if mol is None:
return None
if not _sanitize_safe(mol):
return None
if not keep_aam:
return _strip_maps_and_canonical_from_mol(mol)
# keep_aam True
if _has_atom_maps(mol):
return _canonical_from_mol(mol)
# no maps: produce deterministic mapped canonical SMILES
nomap = _strip_maps_and_canonical_from_mol(mol)
if nomap is None:
return None
return _assign_deterministic_maps_and_canonical(nomap)
def _dedup_key(
*,
dedup_across_rules: bool,
rule_index: int,
r_keep_keys: Tuple[str, ...],
p_keep_keys: Tuple[str, ...],
) -> Tuple[Optional[int], Tuple[str, ...], Tuple[str, ...]]:
ridx: Optional[int] = None if dedup_across_rules else int(rule_index)
return (ridx, r_keep_keys, p_keep_keys)
def _sorted_smiles_from_ids(graph: nx.DiGraph, ids: List[int]) -> str:
return ".".join(sorted(graph.nodes[i]["smiles"] for i in ids))
def _iter_mixtures_arity1(
pool_keys: List[str],
frontier_keys: List[str],
*,
use_frontier: bool,
cap: int,
) -> Iterator[Tuple[str, ...]]:
pool_u = sorted(set(pool_keys))
if not pool_u:
return
frontier_u = sorted(set(frontier_keys))
if not use_frontier:
frontier_u = pool_u
n = 0
for f in frontier_u:
yield (f,)
n += 1
if n >= cap:
return
def _iter_mixtures_arity2(
pool_keys: List[str],
frontier_keys: List[str],
*,
use_frontier: bool,
cap: int,
) -> Iterator[Tuple[str, ...]]:
pool_u = sorted(set(pool_keys))
if not pool_u:
return
frontier_u = sorted(set(frontier_keys))
if not use_frontier:
frontier_u = pool_u
n = 0
for f in frontier_u:
for x in pool_u:
if x == f:
continue
yield (f, x) if f < x else (x, f)
n += 1
if n >= cap:
return
def _iter_mixtures_arityk(
pool_keys: List[str],
frontier_keys: List[str],
*,
use_frontier: bool,
arity: int,
cap: int,
) -> Iterator[Tuple[str, ...]]:
pool_u = sorted(set(pool_keys))
if not pool_u:
return
frontier_u = sorted(set(frontier_keys))
if not use_frontier:
frontier_u = pool_u
n = 0
for f in frontier_u:
others = [x for x in pool_u if x != f]
for comb in combinations(others, arity - 1):
yield tuple(sorted((f, *comb)))
n += 1
if n >= cap:
return
# --------------------------------------------------------------------------- #
# SynCRN
# --------------------------------------------------------------------------- #
@dataclass
class SynCRN:
    """
    Iteratively grow a reaction network from seed SMILES by applying rules.

    The result is a directed bipartite-style graph stored in :attr:`graph`:

    - species nodes: ``kind="species"``, ``smiles`` (standardized),
      ``smiles_nomap`` (canonical key without atom maps), ``label``;
    - reaction-event nodes: ``kind="rxn"``, ``label``, ``step``,
      ``rule_index``, ``rule_name``, ``app_index``, ``rule_repr``;
    - edges ``species -> rxn`` with ``role="reactant"`` and
      ``rxn -> species`` with ``role="product"``.

    Species are identified by their ``smiles_nomap`` key, so the same molecule
    with different atom maps resolves to a single node.

    :raises ValueError: From ``__post_init__`` when ``max_components < 1``.
    """

    # Rule templates; each is passed as ``template`` to the reactor worker.
    rules: List[Any]
    # Maximum number of expansion steps performed by ``build``.
    repeats: int = 50
    # Forwarded unchanged to ``SynReactor.from_smiles``.
    explicit_h: bool = False
    implicit_temp: bool = False
    # Forwarded to ``SynReactor.from_smiles`` only when not ``None``.
    strategy: Optional[str] = None
    # Passed to ``_standardize_smiles_rdkit`` when standardizing species.
    keep_aam: bool = True
    # Rules whose inferred arity exceeds this value are skipped.
    max_components: int = 3
    # When true, every mixture must contain a key discovered in the last step.
    use_frontier: bool = True
    # Upper bound on mixtures enumerated per rule in one step.
    max_mixtures_per_rule_step: int = 50_000
    # Upper bound on tasks created in one step (shared by all rules).
    max_tasks_per_step: int = 200_000
    # Drop events where reactants and products are identical sets.
    skip_no_change: bool = True
    # When false, drop events whose changed reactants or products are empty.
    allow_empty_side: bool = False
    # Record each (reactants, products) delta only once.
    dedup_delta: bool = True
    # When true, the delta de-duplication ignores which rule produced it.
    dedup_across_rules: bool = False

    graph: nx.DiGraph = field(init=False)
    _species_index: Dict[str, int] = field(init=False)  # canonical nomap -> node_id
    _next_node_id: int = field(init=False)
    _smiles_cache: Dict[str, Optional[str]] = field(init=False)
    _nomap_cache: Dict[str, Optional[str]] = field(init=False)
    _app_counter: Dict[Tuple[int, int], int] = field(init=False)
    _seen_attempts: Set[Tuple[int, Tuple[str, ...]]] = field(init=False)
    _seen_delta: Set[Tuple[Optional[int], Tuple[str, ...], Tuple[str, ...]]] = field(
        init=False
    )
    _rule_arity_cache: Dict[int, int] = field(init=False)

    def __post_init__(self) -> None:
        """Validate the configuration and create empty graph/bookkeeping state."""
        if self.max_components < 1:
            raise ValueError("max_components must be >= 1")
        self.graph = nx.DiGraph()
        self._species_index = {}
        self._next_node_id = 1
        self._smiles_cache = {}
        self._nomap_cache = {}
        self._app_counter = {}
        self._seen_attempts = set()
        self._seen_delta = set()
        self._rule_arity_cache = {}

    # ------------------- IDs ------------------- #
    def _alloc_node_id(self) -> int:
        """Return the next unused node id (species and events share one counter)."""
        nid = self._next_node_id
        self._next_node_id += 1
        return nid

    def _next_app_index(self, *, step: int, rule_index: int) -> int:
        """Return a 0-based running index of events for ``(step, rule_index)``."""
        k = (step, rule_index)
        cur = self._app_counter.get(k, 0)
        self._app_counter[k] = cur + 1
        return cur

    # ------------------- SMILES processing ------------------- #
    def _canonical_nomap(self, smiles: str) -> Optional[str]:
        """Memoized wrapper around ``_canonicalize_nomap_rdkit``."""
        if smiles in self._nomap_cache:
            return self._nomap_cache[smiles]
        out = _canonicalize_nomap_rdkit(smiles)
        self._nomap_cache[smiles] = out
        return out

    def _standardize_smiles(self, smiles: str) -> Optional[str]:
        """
        Memoized wrapper around ``_standardize_smiles_rdkit``.

        Logs a warning on a cache miss when RDKit is unavailable and the input
        was passed through unchanged.
        """
        if smiles in self._smiles_cache:
            return self._smiles_cache[smiles]
        out = _standardize_smiles_rdkit(smiles, keep_aam=self.keep_aam)
        self._smiles_cache[smiles] = out
        if Chem is None and out is not None:
            logger.warning(
                "RDKit not available; SMILES standardization is a passthrough."
            )
        return out

    # ------------------- Rule arity ------------------- #
    def _infer_rule_arity(self, rule: Any, rule_index: int) -> int:
        """
        Return (and cache per rule index) the number of reactants a rule takes.

        String rules are counted from their left-hand side; other objects go
        through ``_infer_rule_arity_from_attrs``. Falls back to ``2`` when
        nothing usable is found.
        """
        if rule_index in self._rule_arity_cache:
            return self._rule_arity_cache[rule_index]
        ar: Optional[int] = None
        if isinstance(rule, str):
            ar = _count_lhs_components(rule)
        else:
            ar = self._infer_rule_arity_from_attrs(rule)
        if ar is None or ar < 1:
            ar = 2
        self._rule_arity_cache[rule_index] = int(ar)
        return int(ar)

    def _infer_rule_arity_from_attrs(self, rule: Any) -> Optional[int]:
        """
        Infer arity for a non-string rule from textual attributes or its repr.

        The ``smarts``, ``smirks`` and ``template`` attributes are tried in
        that order, then ``repr(rule)``. Text that did not start out as a
        string is only trusted when it contains a reaction arrow (``>>``):
        a generic object repr such as ``<pkg.mod.Rule object at ...>`` is full
        of dots and would otherwise be read as a multi-component rule and
        then skipped for exceeding ``max_components``.

        :returns: The component count, or ``None`` if nothing usable exists.
        """
        for attr in ("smarts", "smirks", "template"):
            if not hasattr(rule, attr):
                continue
            try:
                value = getattr(rule, attr)
                is_text = isinstance(value, str)
                text = value if is_text else str(value)
                if not is_text and ">>" not in text:
                    continue
                ar = _count_lhs_components(text)
                if ar is not None:
                    return ar
            except Exception:
                continue
        try:
            text = repr(rule)
            if ">>" not in text:
                return None
            return _count_lhs_components(text)
        except Exception:
            return None

    # ------------------- Graph nodes ------------------- #
    def _node_key(self, nid: int) -> Any:
        """Return a node's ``smiles_nomap`` attribute, else its ``smiles``."""
        data = self.graph.nodes[nid]
        return data.get("smiles_nomap", data.get("smiles"))

    def _add_species_node(self, smiles: str) -> Optional[int]:
        """
        Return the node id for *smiles*, creating the species node if needed.

        :returns: The node id, or ``None`` when standardization or
            canonicalization of *smiles* fails.
        """
        std = self._standardize_smiles(smiles)
        if std is None:
            return None
        key = self._canonical_nomap(std)
        if key is None:
            return None
        if key in self._species_index:
            return self._species_index[key]
        nid = self._alloc_node_id()
        self._species_index[key] = nid
        self.graph.add_node(
            nid,
            kind="species",
            smiles=std,
            smiles_nomap=key,
            label=std,
        )
        return nid

    def _add_rxn_event_node(self, *, step: int, rule_index: int) -> int:
        """Create a reaction-event node labelled ``<rule>@<step>@<app_index>``."""
        rule = self.rules[rule_index]
        # Rules without a ``name`` attribute are labelled by position.
        rule_name = getattr(rule, "name", f"r{rule_index}")
        app_index = self._next_app_index(step=step, rule_index=rule_index)
        label = f"{rule_name}@{step}@{app_index}"
        eid = self._alloc_node_id()
        self.graph.add_node(
            eid,
            kind="rxn",
            label=label,
            step=step,
            rule_index=rule_index,
            rule_name=rule_name,
            app_index=app_index,
            rule_repr=repr(rule),
        )
        return eid

    # ------------------- Delta / overlap ------------------- #
    def _delta_keep_ids(
        self,
        reactant_ids: List[int],
        products_raw: List[str],
    ) -> Tuple[List[int], List[int], Tuple[str, ...], Tuple[str, ...]]:
        """
        Split a reaction into the species that actually change.

        Product SMILES are registered as species nodes as a side effect
        (unparsable products are ignored). Species present on both sides are
        removed from both.

        :returns: ``(changed reactant ids, changed product ids,
            sorted changed reactant keys, sorted changed product keys)``.
        """
        r_keys_all = [self._node_key(rid) for rid in reactant_ids]

        p_ids_all: List[int] = []
        p_keys_all: List[str] = []
        for p in products_raw:
            pid = self._add_species_node(p)
            if pid is None:
                continue
            p_ids_all.append(pid)
            p_keys_all.append(self._node_key(pid))

        unchanged = set(r_keys_all) & set(p_keys_all)
        r_keep_ids = [
            rid for rid, rk in zip(reactant_ids, r_keys_all) if rk not in unchanged
        ]
        p_keep_ids = [
            pid for pid, pk in zip(p_ids_all, p_keys_all) if pk not in unchanged
        ]
        r_keep_keys = tuple(sorted(self._node_key(rid) for rid in r_keep_ids))
        p_keep_keys = tuple(sorted(self._node_key(pid) for pid in p_keep_ids))
        return r_keep_ids, p_keep_ids, r_keep_keys, p_keep_keys

    # ------------------- Mixture generation ------------------- #
    def _iter_mixtures_for_rule(
        self,
        pool_keys: List[str],
        frontier_keys: List[str],
        *,
        arity: int,
        cap: int,
    ) -> Iterator[Tuple[str, ...]]:
        """
        Dispatch to the mixture iterator matching *arity*.

        :returns: An empty iterator when *arity* is below 1 or above
            ``max_components``; otherwise the arity-specific iterator.
        """
        if arity < 1 or arity > self.max_components:
            return iter(())
        if arity == 1:
            return _iter_mixtures_arity1(
                pool_keys,
                frontier_keys,
                use_frontier=self.use_frontier,
                cap=cap,
            )
        if arity == 2:
            return _iter_mixtures_arity2(
                pool_keys,
                frontier_keys,
                use_frontier=self.use_frontier,
                cap=cap,
            )
        return _iter_mixtures_arityk(
            pool_keys,
            frontier_keys,
            use_frontier=self.use_frontier,
            arity=arity,
            cap=cap,
        )

    # ------------------- Build orchestration split to reduce C901 ------------------- #
    def build(
        self,
        seeds: Iterable[str],
        *,
        parallel: bool = False,
        max_workers: Optional[int] = None,
    ) -> nx.DiGraph:
        """
        Expand the network from *seeds* for at most ``repeats`` steps.

        Expansion stops early when no seed is usable, when a step produces no
        tasks, or (with ``use_frontier``) when a step discovers no new species.

        :param seeds: Seed SMILES; entries that fail standardization are
            ignored.
        :param parallel: Run rule applications in a process pool when a step
            has more than one task.
        :param max_workers: Passed to ``ProcessPoolExecutor``.
        :returns: The instance's graph (also available as :attr:`graph`).
        """
        pool_keys, frontier_keys = self._init_pool(seeds)
        if not pool_keys:
            return self.graph

        for step in range(1, self.repeats + 1):
            if self.use_frontier and not frontier_keys:
                break
            tasks = self._make_tasks_for_step(pool_keys, frontier_keys, step)
            if not tasks:
                break
            results = self._run_tasks(tasks, parallel=parallel, max_workers=max_workers)
            frontier_keys = self._integrate_results(results, pool_keys, step)
            if self.use_frontier and not frontier_keys:
                break
        return self.graph

    def _init_pool(self, seeds: Iterable[str]) -> Tuple[Set[str], Set[str]]:
        """Register seeds as species; return ``(pool keys, frontier keys)``."""
        pool_keys: Set[str] = set()
        frontier_keys: Set[str] = set()
        for s in seeds:
            sid = self._add_species_node(s)
            if sid is None:
                continue
            k = self._node_key(sid)
            pool_keys.add(k)
            frontier_keys.add(k)
        return pool_keys, frontier_keys

    def _make_tasks_for_step(
        self,
        pool_keys: Set[str],
        frontier_keys: Set[str],
        step: int,
    ) -> List[Tuple[int, Any, str, bool, bool, Optional[str], Tuple[str, ...]]]:
        """
        Enumerate the worker tasks (rule x mixture) for one expansion step.

        Rules are visited in order and share the ``max_tasks_per_step``
        budget; mixtures already attempted for a rule do not produce a task.
        """
        budget = int(self.max_tasks_per_step)
        tasks: List[
            Tuple[int, Any, str, bool, bool, Optional[str], Tuple[str, ...]]
        ] = []
        # Pre-cache id->smiles for the pool to avoid repeated graph lookups.
        id_to_smiles = self._cache_id_to_smiles(pool_keys)
        # The pool and frontier do not change while tasks are enumerated.
        pool_list = list(pool_keys)
        frontier_list = list(frontier_keys)

        for ridx, rule in enumerate(self.rules):
            if budget <= 0:
                break
            arity = self._infer_rule_arity(rule, ridx)
            if arity > self.max_components:
                continue
            cap = min(int(self.max_mixtures_per_rule_step), budget)
            mix_iter = self._iter_mixtures_for_rule(
                pool_list,
                frontier_list,
                arity=arity,
                cap=cap,
            )
            for mix_keys in mix_iter:
                if budget <= 0:
                    break
                task = self._task_from_mix(
                    ridx=ridx,
                    rule=rule,
                    mix_keys=mix_keys,
                    id_to_smiles=id_to_smiles,
                )
                if task is None:
                    continue
                tasks.append(task)
                budget -= 1
        return tasks

    def _cache_id_to_smiles(self, pool_keys: Set[str]) -> Dict[int, str]:
        """Map the node id of every indexed pool key to that node's ``smiles``."""
        id_to_smiles: Dict[int, str] = {}
        for k in pool_keys:
            nid = self._species_index.get(k)
            if nid is None:
                continue
            id_to_smiles[nid] = self.graph.nodes[nid]["smiles"]
        return id_to_smiles

    def _task_from_mix(
        self,
        *,
        ridx: int,
        rule: Any,
        mix_keys: Tuple[str, ...],
        id_to_smiles: Dict[int, str],
    ) -> Optional[Tuple[int, Any, str, bool, bool, Optional[str], Tuple[str, ...]]]:
        """
        Build the worker argument tuple for one ``(rule, mixture)`` attempt.

        The attempt is recorded in ``_seen_attempts`` even when no task is
        produced.

        :returns: The task tuple, or ``None`` if this attempt was already made
            or a mixture key has no species node.
        """
        app_key = (int(ridx), mix_keys)
        if app_key in self._seen_attempts:
            return None
        self._seen_attempts.add(app_key)

        reactant_ids = [self._species_index.get(k) for k in mix_keys]
        if any(nid is None for nid in reactant_ids):
            return None
        ids = [int(nid) for nid in reactant_ids if nid is not None]

        # Build substrate from cached smiles
        try:
            substrate = ".".join(sorted(id_to_smiles[i] for i in ids))
        except KeyError:
            # A reactant not in current pool cache; fallback to graph lookup
            substrate = _sorted_smiles_from_ids(self.graph, ids)

        return (
            int(ridx),
            rule,
            substrate,
            self.explicit_h,
            self.implicit_temp,
            self.strategy,
            mix_keys,
        )

    def _run_tasks(
        self,
        tasks: List[Tuple[int, Any, str, bool, bool, Optional[str], Tuple[str, ...]]],
        *,
        parallel: bool,
        max_workers: Optional[int],
    ) -> List[Tuple[int, Tuple[str, ...], List[str]]]:
        """
        Run every task through ``_apply_rule_worker``.

        A process pool is used only when *parallel* is true and there is more
        than one task; results are returned in task order either way.
        """
        if parallel and len(tasks) > 1:
            with ProcessPoolExecutor(max_workers=max_workers) as ex:
                return [
                    (idx, mix_keys, products_list)
                    for idx, mix_keys, products_list in ex.map(
                        _apply_rule_worker, tasks
                    )
                ]
        return [_apply_rule_worker(t) for t in tasks]

    def _integrate_results(
        self,
        results: List[Tuple[int, Tuple[str, ...], List[str]]],
        pool_keys: Set[str],
        step: int,
    ) -> Set[str]:
        """
        Merge worker results into the graph and pool.

        *pool_keys* is extended in place with newly discovered species.

        :returns: The keys first discovered in this step (the next frontier).
        """
        next_frontier: Set[str] = set()
        for rule_index, mix_keys, products_list in results:
            if not products_list:
                continue
            self._integrate_one_result(
                rule_index=rule_index,
                mix_keys=mix_keys,
                products_list=products_list,
                pool_keys=pool_keys,
                next_frontier=next_frontier,
                step=step,
            )
        return next_frontier

    def _integrate_one_result(
        self,
        *,
        rule_index: int,
        mix_keys: Tuple[str, ...],
        products_list: List[str],
        pool_keys: Set[str],
        next_frontier: Set[str],
        step: int,
    ) -> None:
        """
        Record the outcomes of one rule application as events and edges.

        Each entry of *products_list* is a dot-separated product mixture.
        Outcomes filtered by ``skip_no_change``, ``allow_empty_side`` or
        ``dedup_delta`` get no event node and do not update the pool.
        """
        reactant_ids = [self._species_index.get(k) for k in mix_keys]
        if any(nid is None for nid in reactant_ids):
            return
        reactant_ids_int = [int(nid) for nid in reactant_ids if nid is not None]

        for prod_mix in products_list:
            prod_mix = (prod_mix or "").strip()
            if not prod_mix:
                continue
            products_raw = [s for s in prod_mix.split(".") if s]

            r_keep, p_keep, r_keep_keys, p_keep_keys = self._delta_keep_ids(
                reactant_ids_int,
                products_raw,
            )
            if self.skip_no_change and (not r_keep and not p_keep):
                continue
            if (not self.allow_empty_side) and (not r_keep or not p_keep):
                continue

            if self.dedup_delta:
                dkey = _dedup_key(
                    dedup_across_rules=self.dedup_across_rules,
                    rule_index=int(rule_index),
                    r_keep_keys=r_keep_keys,
                    p_keep_keys=p_keep_keys,
                )
                if dkey in self._seen_delta:
                    continue
                self._seen_delta.add(dkey)

            eid = self._add_rxn_event_node(step=step, rule_index=int(rule_index))
            for rid in r_keep:
                self.graph.add_edge(
                    rid,
                    eid,
                    step=step,
                    rule_index=int(rule_index),
                    rxn_id=eid,
                    role="reactant",
                )
            for pid in p_keep:
                self.graph.add_edge(
                    eid,
                    pid,
                    step=step,
                    rule_index=int(rule_index),
                    rxn_id=eid,
                    role="product",
                )
            self._update_pool_with_products(products_raw, pool_keys, next_frontier)

    def _update_pool_with_products(
        self,
        products_raw: List[str],
        pool_keys: Set[str],
        next_frontier: Set[str],
    ) -> None:
        """Add products not yet in *pool_keys* to the pool and the next frontier."""
        for p in products_raw:
            pid_all = self._add_species_node(p)
            if pid_all is None:
                continue
            pk = self._node_key(pid_all)
            if pk not in pool_keys:
                pool_keys.add(pk)
                next_frontier.add(pk)

    # ------------------- convenience ------------------- #
    @property
    def species_nodes(self) -> List[int]:
        """Ids of all nodes with ``kind == "species"``."""
        return [n for n, d in self.graph.nodes(data=True) if d.get("kind") == "species"]

    @property
    def rxn_nodes(self) -> List[int]:
        """Ids of all nodes with ``kind == "rxn"``."""
        return [n for n, d in self.graph.nodes(data=True) if d.get("kind") == "rxn"]
# --------------------------------------------------------------------------- #
# Flattening
# --------------------------------------------------------------------------- #
@dataclass
class ReactionDeltaFlattener:
    """
    Flatten a species/reaction-event graph into a list of reaction records.

    Every node with ``kind == "rxn"`` becomes at most one record built from
    its ``role="reactant"`` in-edges and ``role="product"`` out-edges. Species
    appearing on both sides are dropped from both. Call :meth:`build` first,
    then read :attr:`reactions`.
    """

    # Graph to flatten; nodes carry ``kind`` and, for species, ``smiles``.
    graph: nx.DiGraph
    # Drop events with no changed species on either side.
    skip_no_change: bool = True
    # When false, drop events whose changed reactants or products are empty.
    allow_empty_side: bool = False
    # Keep only the first event for each distinct reaction SMILES.
    deduplicate: bool = True
    _cache: List[Dict[str, Any]] = field(default_factory=list, init=False)

    def build(self) -> "ReactionDeltaFlattener":
        """Compute and store the flattened records; returns ``self``."""
        self._cache = self._flatten()
        return self

    @property
    def reactions(self) -> List[Dict[str, Any]]:
        """Shallow copy of the records produced by the last :meth:`build`."""
        return list(self._cache)

    def _collect_in(self, eid: int) -> List[str]:
        """SMILES of species feeding *eid* through ``role="reactant"`` edges."""
        nodes = self.graph.nodes
        return [
            # Species without a ``smiles`` attribute fall back to their id.
            nodes[u].get("smiles", str(u))
            for u, _, ed in self.graph.in_edges(eid, data=True)
            if ed.get("role") == "reactant" and nodes[u].get("kind") == "species"
        ]

    def _collect_out(self, eid: int) -> List[str]:
        """SMILES of species reached from *eid* through ``role="product"`` edges."""
        nodes = self.graph.nodes
        return [
            nodes[v].get("smiles", str(v))
            for _, v, ed in self.graph.out_edges(eid, data=True)
            if ed.get("role") == "product" and nodes[v].get("kind") == "species"
        ]

    def _nz(self, x: Optional[int]) -> int:
        """Sort helper: ``None`` sorts last (mapped to ``10**9``)."""
        return 10**9 if x is None else int(x)

    def _flatten(self) -> List[Dict[str, Any]]:
        """
        Build the reaction records, sorted by step, rule, application, node id.

        De-duplication (when enabled) keeps the first event encountered in
        graph node order, before sorting.
        """
        out: List[Dict[str, Any]] = []
        seen: Set[str] = set()

        for eid, nd in self.graph.nodes(data=True):
            if nd.get("kind") != "rxn":
                continue

            r = sorted(set(self._collect_in(eid)))
            p = sorted(set(self._collect_out(eid)))
            unchanged = set(r) & set(p)
            rchg = [x for x in r if x not in unchanged]
            pchg = [x for x in p if x not in unchanged]

            if self.skip_no_change and not rchg and not pchg:
                continue
            if (not self.allow_empty_side) and (not rchg or not pchg):
                continue

            rxn_smiles = f"{'.'.join(rchg)}>>{'.'.join(pchg)}"
            if self.deduplicate and rxn_smiles in seen:
                continue
            seen.add(rxn_smiles)

            out.append(
                {
                    "rxn_id": eid,
                    "label": nd.get("label"),
                    "step": nd.get("step"),
                    "rule_index": nd.get("rule_index"),
                    "app_index": nd.get("app_index"),
                    "reactants": rchg,
                    "products": pchg,
                    "rxn_smiles": rxn_smiles,
                }
            )

        out.sort(
            key=lambda rec: (
                self._nz(rec.get("step")),
                self._nz(rec.get("rule_index")),
                self._nz(rec.get("app_index")),
                rec["rxn_id"],
            )
        )
        return out
# --------------------------------------------------------------------------- #
# Convenience wrapper (line-length safe)
# --------------------------------------------------------------------------- #
def build_syncrn_from_smarts(
    rules: List[str],
    seeds: List[str],
    *,
    repeats: int = 50,
    explicit_h: bool = False,
    implicit_temp: bool = False,
    strategy: Optional[str] = None,
    keep_aam: bool = True,
    parallel: bool = False,
    max_workers: Optional[int] = None,
    max_components: int = 3,
    use_frontier: bool = True,
    max_mixtures_per_rule_step: int = 50_000,
    max_tasks_per_step: int = 200_000,
    skip_no_change: bool = True,
    allow_empty_side: bool = False,
    dedup_delta: bool = True,
    dedup_across_rules: bool = False,
) -> nx.DiGraph:
    """
    One-call helper: configure a ``SynCRN`` and build its graph from *seeds*.

    All keyword options except *parallel* and *max_workers* are passed to the
    ``SynCRN`` constructor under the same names; those two are passed to
    ``SynCRN.build``.

    :param rules: Rule strings handed to ``SynCRN`` as ``rules``.
    :param seeds: Seed SMILES handed to ``SynCRN.build``.
    :returns: The graph returned by ``SynCRN.build``.
    """
    settings: Dict[str, Any] = {
        "repeats": repeats,
        "explicit_h": explicit_h,
        "implicit_temp": implicit_temp,
        "strategy": strategy,
        "keep_aam": keep_aam,
        "max_components": max_components,
        "use_frontier": use_frontier,
        "max_mixtures_per_rule_step": max_mixtures_per_rule_step,
        "max_tasks_per_step": max_tasks_per_step,
        "skip_no_change": skip_no_change,
        "allow_empty_side": allow_empty_side,
        "dedup_delta": dedup_delta,
        "dedup_across_rules": dedup_across_rules,
    }
    network = SynCRN(rules=rules, **settings)
    return network.build(seeds, parallel=parallel, max_workers=max_workers)