# Source code for synkit.CRN.Construct.builder

from __future__ import annotations

from collections import Counter
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple
import logging
from concurrent.futures import ProcessPoolExecutor

import networkx as nx

from .arity import infer_rule_arity
from .keys import make_dedup_key
from .smiles import Chem, standardize_smiles_rdkit
from .worker import apply_rule_worker
from .state import DerivationState
from .strategy import FrontierStrategy
from .derivation import DerivationLog

logger = logging.getLogger(__name__)


@dataclass
class CRNExpand:
    """Expand a chemical reaction network (CRN) by repeated rule application.

    Starting from seed SMILES, every rule in ``rules`` is applied to mixtures
    of known species for at most ``repeats`` steps. The outcome is stored in
    ``graph``, a ``networkx.DiGraph`` whose nodes carry ``kind="species"`` or
    ``kind="rule"`` (one node per recorded rule application). Edges run from
    reactant species to the event node and from the event node to product
    species, each annotated with ``role`` and ``stoich``.

    Attributes:
        rules: Rule objects. They are handed to ``infer_rule_arity`` and to
            ``apply_rule_worker``; their ``repr`` is stored on event nodes.
        repeats: Maximum number of expansion steps.
        explicit_h: Forwarded unchanged inside every worker task.
        implicit_temp: Forwarded unchanged inside every worker task.
        strategy: Forwarded unchanged inside every worker task.
        keep_aam: Forwarded to ``standardize_smiles_rdkit``.
        max_components: Must be >= 1. Rules whose inferred arity exceeds this
            value are skipped; the value is also passed to the strategy engine.
        use_frontier: Passed to the strategy engine. When true, expansion
            stops as soon as a step discovers no new species.
        max_mixtures_per_rule_step: Upper bound on the mixture cap requested
            from the strategy engine per rule and step.
        max_tasks_per_step: Total number of tasks created per step.
        allow_self_mixtures: Passed to the strategy engine.
        skip_no_change: Drop results in which both sides are empty once the
            species shared by reactants and products are removed.
        allow_empty_side: When false, drop results in which either side is
            empty after that removal.
        dedup_delta: Record each (reactant, product) delta only once.
        dedup_across_rules: Passed to ``make_dedup_key``.
            NOTE(review): presumably decides whether the rule index is part
            of the dedup key - confirm against ``keys.make_dedup_key``.
    """

    rules: List[Any]
    repeats: int = 50
    explicit_h: bool = False
    implicit_temp: bool = False
    strategy: Optional[str] = None
    keep_aam: bool = True

    max_components: int = 3
    use_frontier: bool = True
    max_mixtures_per_rule_step: int = 50_000
    max_tasks_per_step: int = 200_000
    allow_self_mixtures: bool = False

    skip_no_change: bool = True
    allow_empty_side: bool = False
    dedup_delta: bool = True
    dedup_across_rules: bool = False

    # Mutable run state; (re)created by ``reset``.
    graph: nx.DiGraph = field(init=False)
    _species_index: Dict[str, int] = field(init=False)
    _next_node_id: int = field(init=False)
    _smiles_cache: Dict[str, Optional[str]] = field(init=False)
    _app_counter: Dict[int, int] = field(init=False)
    _seen_attempts: Set[Tuple[int, Tuple[str, ...]]] = field(init=False)
    _seen_delta: Set[Tuple[Optional[int], Tuple[str, ...], Tuple[str, ...]]] = field(
        init=False
    )
    _rule_arity_cache: Dict[int, int] = field(init=False)
    _warned_no_rdkit: bool = field(init=False)

    # Collaborators; created once in ``__post_init__``.
    state: DerivationState = field(init=False, repr=False)
    strategy_engine: FrontierStrategy = field(init=False, repr=False)
    derivations: DerivationLog = field(init=False, repr=False)

    def __post_init__(self) -> None:
        """Validate the configuration and create the initial run state.

        Raises:
            ValueError: If ``max_components`` is smaller than 1.
        """
        if self.max_components < 1:
            raise ValueError("max_components must be >= 1")
        self.state = DerivationState()
        self.strategy_engine = FrontierStrategy()
        self.derivations = DerivationLog()
        self.reset()

    def reset(self) -> None:
        """Discard the graph, all caches and bookkeeping of previous runs."""
        self.graph = nx.DiGraph()
        self._species_index = {}
        self._next_node_id = 1
        self._smiles_cache = {}
        self._app_counter = {}
        self._seen_attempts = set()
        self._seen_delta = set()
        self._rule_arity_cache = {}
        self._warned_no_rdkit = False
        self.state.set_initial(pool_keys=set(), frontier_keys=set())
        self.derivations.clear()

    def _alloc_node_id(self) -> int:
        """Return the next unused node id (species and events share ids)."""
        nid = self._next_node_id
        self._next_node_id += 1
        return nid

    def _next_app_index(self, *, rule_index: int) -> int:
        """Return the 1-based application counter for ``rule_index``."""
        cur = self._app_counter.get(rule_index, 0) + 1
        self._app_counter[rule_index] = cur
        return cur

    def _standardize_smiles(self, smiles: str) -> Optional[str]:
        """Return the standardized form of ``smiles``, memoized per input.

        ``None`` results are cached as well. On a cache miss, a warning is
        logged once per reset when ``Chem`` is ``None`` and a non-``None``
        result was still produced.
        """
        if smiles in self._smiles_cache:
            return self._smiles_cache[smiles]
        out = standardize_smiles_rdkit(smiles, keep_aam=self.keep_aam)
        self._smiles_cache[smiles] = out
        if Chem is None and out is not None and not self._warned_no_rdkit:
            logger.warning(
                "RDKit not available; SMILES standardization is a passthrough."
            )
            self._warned_no_rdkit = True
        return out

    def _standardize_product_mixture(self, prod_mix: str) -> List[str]:
        """Split a dot-separated mixture and standardize every fragment.

        Empty fragments and fragments that standardize to ``None`` are
        dropped; order and duplicates of the remaining ones are preserved.
        """
        out: List[str] = []
        for s in (prod_mix or "").split("."):
            if not s:
                continue
            std = self._standardize_smiles(s)
            if std is None:
                continue
            out.append(std)
        return out

    def _infer_rule_arity(self, rule: Any, rule_index: int) -> int:
        """Return the arity of ``rule`` as an int, cached by ``rule_index``."""
        if rule_index in self._rule_arity_cache:
            return self._rule_arity_cache[rule_index]
        arity = infer_rule_arity(rule)
        self._rule_arity_cache[rule_index] = int(arity)
        return int(arity)

    def _add_species_node(self, smiles: str) -> Optional[int]:
        """Return the node id for ``smiles``, creating the node if needed.

        Returns ``None`` when the SMILES cannot be standardized. Species are
        keyed by their standardized SMILES, so equal species share a node.
        """
        std = self._standardize_smiles(smiles)
        if std is None:
            return None
        if std in self._species_index:
            return self._species_index[std]
        nid = self._alloc_node_id()
        self._species_index[std] = nid
        self.graph.add_node(
            nid,
            kind="species",
            smiles=std,
            label=std,
        )
        return nid

    def _add_rxn_event_node(self, *, step: int, rule_index: int) -> int:
        """Create a ``kind="rule"`` node for one rule application.

        The node label is ``r@<rule_index>@<app_index>``.
        """
        app_index = self._next_app_index(rule_index=rule_index)
        label = f"r@{rule_index}@{app_index}"
        eid = self._alloc_node_id()
        self.graph.add_node(
            eid,
            kind="rule",
            label=label,
            step=step,
            rule_index=rule_index,
            app_index=app_index,
            rule_repr=repr(self.rules[rule_index]),
        )
        return eid

    def _delta_keep_smiles(
        self,
        reactant_ids: List[int],
        products_std_all: List[str],
    ) -> Tuple[List[int], List[str], Tuple[str, ...], Tuple[str, ...]]:
        """Remove species that occur on both sides of a reaction.

        A species present among both reactants and products is removed from
        both sides entirely, regardless of how often it occurs on each side.

        Returns:
            ``(kept reactant ids, kept product SMILES, sorted kept reactant
            SMILES, sorted kept product SMILES)``.
        """
        r_smiles_all = [self.graph.nodes[rid]["smiles"] for rid in reactant_ids]
        unchanged = set(r_smiles_all) & set(products_std_all)
        r_keep_ids = [
            rid for rid, rs in zip(reactant_ids, r_smiles_all) if rs not in unchanged
        ]
        p_keep_smiles = [ps for ps in products_std_all if ps not in unchanged]
        r_keep_keys = tuple(
            sorted(self.graph.nodes[rid]["smiles"] for rid in r_keep_ids)
        )
        p_keep_keys = tuple(sorted(p_keep_smiles))
        return r_keep_ids, p_keep_smiles, r_keep_keys, p_keep_keys

    def _iter_mixtures_for_rule(
        self,
        pool_keys: List[str],
        frontier_keys: List[str],
        *,
        arity: int,
        cap: int,
    ) -> Iterator[Tuple[str, ...]]:
        """Delegate mixture enumeration to ``strategy_engine.iter_mixtures``."""
        return self.strategy_engine.iter_mixtures(
            pool_keys=pool_keys,
            frontier_keys=frontier_keys,
            arity=arity,
            use_frontier=self.use_frontier,
            allow_self_mixtures=self.allow_self_mixtures,
            cap=cap,
            max_components=self.max_components,
        )

    def build(
        self,
        seeds: Iterable[str],
        *,
        parallel: bool = False,
        max_workers: Optional[int] = None,
        reset: bool = True,
    ) -> nx.DiGraph:
        """Run the expansion and return the resulting graph.

        Args:
            seeds: SMILES of the initial species. Seeds that cannot be
                standardized are ignored.
            parallel: Apply rules in a ``ProcessPoolExecutor`` when a step
                has more than one task.
            max_workers: Worker count for the process pool.
            reset: Clear graph and caches before starting. With ``False`` the
                existing graph is extended, but the species pool of this run
                is still initialized from ``seeds`` only.

        Returns:
            ``self.graph`` (the same object that is stored on the instance).
        """
        if reset:
            self.reset()

        pool_keys, frontier_keys = self._init_pool(seeds)
        self.state.set_initial(pool_keys=pool_keys, frontier_keys=frontier_keys)
        if not pool_keys:
            return self.graph

        for step in range(1, self.repeats + 1):
            self.state.begin_step(step)
            if self.use_frontier and not frontier_keys:
                break

            tasks = self._make_tasks_for_step(pool_keys, frontier_keys)
            if not tasks:
                break

            results = self._run_tasks(tasks, parallel=parallel, max_workers=max_workers)
            # ``pool_keys`` is extended in place with newly found species.
            next_frontier = self._integrate_results(results, pool_keys, step)

            frontier_keys = next_frontier
            self.state.advance(next_frontier)

            if self.use_frontier and not frontier_keys:
                break

        return self.graph

    def _init_pool(self, seeds: Iterable[str]) -> Tuple[Set[str], Set[str]]:
        """Register the seeds as species; return ``(pool, frontier)`` key sets.

        Both sets contain the standardized SMILES of every usable seed.
        """
        pool_keys: Set[str] = set()
        frontier_keys: Set[str] = set()
        for s in seeds:
            sid = self._add_species_node(s)
            if sid is None:
                continue
            k = self.graph.nodes[sid]["smiles"]
            pool_keys.add(k)
            frontier_keys.add(k)
        return pool_keys, frontier_keys

    def _make_tasks_for_step(
        self,
        pool_keys: Set[str],
        frontier_keys: Set[str],
    ) -> List[Tuple[int, Any, str, bool, bool, Optional[str], Tuple[str, ...]]]:
        """Create the worker tasks for one expansion step.

        At most ``max_tasks_per_step`` tasks are created in total, and at
        most ``max_mixtures_per_rule_step`` mixtures are requested per rule.
        Rules whose arity exceeds ``max_components`` are skipped.
        """
        budget = int(self.max_tasks_per_step)
        tasks: List[
            Tuple[int, Any, str, bool, bool, Optional[str], Tuple[str, ...]]
        ] = []

        # Sets of str iterate in a hash-dependent order that may change from
        # one interpreter run to the next. Sorting once (instead of calling
        # list() per rule) keeps enumeration - and therefore the result under
        # the cap/budget limits - reproducible and hoists the loop-invariant
        # conversion.
        # NOTE(review): the same two lists are shared by all rules; this
        # assumes ``iter_mixtures`` does not mutate its inputs - confirm.
        pool_sorted = sorted(pool_keys)
        frontier_sorted = sorted(frontier_keys)

        for ridx, rule in enumerate(self.rules):
            if budget <= 0:
                break

            arity = self._infer_rule_arity(rule, ridx)
            if arity > self.max_components:
                continue

            cap = min(int(self.max_mixtures_per_rule_step), budget)
            mix_iter = self._iter_mixtures_for_rule(
                pool_sorted,
                frontier_sorted,
                arity=arity,
                cap=cap,
            )

            for mix_keys in mix_iter:
                if budget <= 0:
                    break
                task = self._task_from_mix(
                    ridx=ridx,
                    rule=rule,
                    mix_keys=mix_keys,
                )
                if task is None:
                    continue
                tasks.append(task)
                budget -= 1

        return tasks

    def _task_from_mix(
        self,
        *,
        ridx: int,
        rule: Any,
        mix_keys: Tuple[str, ...],
    ) -> Optional[Tuple[int, Any, str, bool, bool, Optional[str], Tuple[str, ...]]]:
        """Build the worker task for one (rule, mixture) pair.

        Returns ``None`` if this exact pair was attempted before; otherwise
        the pair is recorded as attempted and the task tuple is returned.
        """
        app_key = (int(ridx), mix_keys)
        if app_key in self._seen_attempts:
            return None
        self._seen_attempts.add(app_key)

        substrate = ".".join(mix_keys)
        return (
            int(ridx),
            rule,
            substrate,
            self.explicit_h,
            self.implicit_temp,
            self.strategy,
            mix_keys,
        )

    def _run_tasks(
        self,
        tasks: List[Tuple[int, Any, str, bool, bool, Optional[str], Tuple[str, ...]]],
        *,
        parallel: bool,
        max_workers: Optional[int],
    ) -> List[Tuple[int, Tuple[str, ...], List[str]]]:
        """Run ``apply_rule_worker`` on every task, in task order.

        A process pool is used only when ``parallel`` is true and there is
        more than one task; otherwise tasks run in the current process.
        """
        if parallel and len(tasks) > 1:
            with ProcessPoolExecutor(max_workers=max_workers) as ex:
                return [
                    (idx, mix_keys, products_list)
                    for idx, mix_keys, products_list in ex.map(
                        apply_rule_worker, tasks
                    )
                ]
        return [apply_rule_worker(t) for t in tasks]

    def _integrate_results(
        self,
        results: List[Tuple[int, Tuple[str, ...], List[str]]],
        pool_keys: Set[str],
        step: int,
    ) -> Set[str]:
        """Merge worker results into the graph; return newly found species.

        ``pool_keys`` is extended in place with every species returned.
        """
        next_frontier: Set[str] = set()
        for rule_index, mix_keys, products_list in results:
            if not products_list:
                continue
            self._integrate_one_result(
                rule_index=rule_index,
                mix_keys=mix_keys,
                products_list=products_list,
                pool_keys=pool_keys,
                next_frontier=next_frontier,
                step=step,
            )
        return next_frontier

    def _integrate_one_result(
        self,
        *,
        rule_index: int,
        mix_keys: Tuple[str, ...],
        products_list: List[str],
        pool_keys: Set[str],
        next_frontier: Set[str],
        step: int,
    ) -> None:
        """Record every distinct product mixture of one rule application.

        Each accepted mixture yields one event node, its reactant/product
        edges, one derivation record, and a pool/frontier update. Mixtures
        rejected by ``skip_no_change``, ``allow_empty_side`` or
        ``dedup_delta`` leave the graph and the pool untouched.
        """
        reactant_ids = [self._species_index.get(k) for k in mix_keys]
        if any(nid is None for nid in reactant_ids):
            # A reactant is unknown to the index; nothing can be recorded.
            return
        reactant_ids_int = [int(nid) for nid in reactant_ids if nid is not None]

        seen_prod_mix: Set[str] = set()
        for prod_mix in products_list:
            prod_mix = (prod_mix or "").strip()
            if not prod_mix or prod_mix in seen_prod_mix:
                continue
            seen_prod_mix.add(prod_mix)

            products_std_all = self._standardize_product_mixture(prod_mix)
            if not products_std_all:
                continue

            r_keep, p_keep_smiles, r_keep_keys, p_keep_keys = self._delta_keep_smiles(
                reactant_ids_int,
                products_std_all,
            )

            if self.skip_no_change and (not r_keep and not p_keep_smiles):
                continue
            if (not self.allow_empty_side) and (not r_keep or not p_keep_smiles):
                continue

            if self.dedup_delta:
                dkey = make_dedup_key(
                    dedup_across_rules=self.dedup_across_rules,
                    rule_index=int(rule_index),
                    r_keep_keys=r_keep_keys,
                    p_keep_keys=p_keep_keys,
                )
                if dkey in self._seen_delta:
                    continue
                self._seen_delta.add(dkey)

            eid = self._add_rxn_event_node(step=step, rule_index=int(rule_index))
            nd = self.graph.nodes[eid]

            self._add_reactant_edges(
                reactant_ids=r_keep,
                eid=eid,
                step=step,
                rule_index=int(rule_index),
            )
            self._add_product_edges(
                product_smiles=p_keep_smiles,
                eid=eid,
                step=step,
                rule_index=int(rule_index),
            )

            # ``r_keep_keys`` / ``p_keep_keys`` already are the sorted SMILES
            # tuples of the kept sides, so they are reused here instead of
            # being rebuilt from the graph.
            self.derivations.append(
                event_id=eid,
                label=str(nd.get("label")),
                step=int(step),
                rule_index=int(rule_index),
                reactants=r_keep_keys,
                products=p_keep_keys,
            )

            self._update_pool_with_products(
                products_std_all=products_std_all,
                pool_keys=pool_keys,
                next_frontier=next_frontier,
            )

    def _add_reactant_edges(
        self,
        *,
        reactant_ids: List[int],
        eid: int,
        step: int,
        rule_index: int,
    ) -> None:
        """Add species -> event edges; ``stoich`` counts repeated reactants."""
        for rid, stoich in Counter(reactant_ids).items():
            self.graph.add_edge(
                rid,
                eid,
                step=step,
                rule_index=rule_index,
                rxn_id=eid,
                role="reactant",
                stoich=int(stoich),
            )

    def _add_product_edges(
        self,
        *,
        product_smiles: List[str],
        eid: int,
        step: int,
        rule_index: int,
    ) -> None:
        """Add event -> species edges, creating product nodes as needed.

        ``stoich`` counts repeated products; products that cannot be
        standardized are skipped.
        """
        for ps, stoich in Counter(product_smiles).items():
            pid = self._add_species_node(ps)
            if pid is None:
                continue
            self.graph.add_edge(
                eid,
                pid,
                step=step,
                rule_index=rule_index,
                rxn_id=eid,
                role="product",
                stoich=int(stoich),
            )

    def _update_pool_with_products(
        self,
        *,
        products_std_all: List[str],
        pool_keys: Set[str],
        next_frontier: Set[str],
    ) -> None:
        """Add products missing from the pool to the pool and next frontier.

        Both ``pool_keys`` and ``next_frontier`` are modified in place.
        """
        for ps in set(products_std_all):
            pid = self._add_species_node(ps)
            if pid is None:
                continue
            pk = self.graph.nodes[pid]["smiles"]
            if pk not in pool_keys:
                pool_keys.add(pk)
                next_frontier.add(pk)

    @property
    def species_nodes(self) -> List[int]:
        """Ids of all nodes with ``kind == "species"``."""
        return [n for n, d in self.graph.nodes(data=True) if d.get("kind") == "species"]

    @property
    def rxn_nodes(self) -> List[int]:
        """Ids of all nodes with ``kind == "rule"`` (rule applications)."""
        return [n for n, d in self.graph.nodes(data=True) if d.get("kind") == "rule"]

    @property
    def derivation_records(self) -> List[Dict[str, object]]:
        """Derivation log entries, as returned by ``derivations.as_dicts()``."""
        return self.derivations.as_dicts()
def build_crn_from_smarts(
    rules: List[str],
    seeds: List[str],
    *,
    repeats: int = 50,
    explicit_h: bool = False,
    implicit_temp: bool = False,
    strategy: Optional[str] = None,
    keep_aam: bool = True,
    parallel: bool = False,
    max_workers: Optional[int] = None,
    max_components: int = 3,
    use_frontier: bool = True,
    max_mixtures_per_rule_step: int = 50_000,
    max_tasks_per_step: int = 200_000,
    allow_self_mixtures: bool = False,
    skip_no_change: bool = True,
    allow_empty_side: bool = False,
    dedup_delta: bool = True,
    dedup_across_rules: bool = False,
) -> nx.DiGraph:
    """Build a reaction network from rules and seeds in a single call.

    Convenience wrapper: creates a ``CRNExpand`` from the given options and
    returns the graph produced by its ``build`` method. ``parallel`` and
    ``max_workers`` are passed to ``build``; every other keyword is passed
    to the ``CRNExpand`` constructor under the same name.
    """
    # Options consumed by the expander itself (as opposed to ``build``).
    expander_options = {
        "repeats": repeats,
        "explicit_h": explicit_h,
        "implicit_temp": implicit_temp,
        "strategy": strategy,
        "keep_aam": keep_aam,
        "max_components": max_components,
        "use_frontier": use_frontier,
        "max_mixtures_per_rule_step": max_mixtures_per_rule_step,
        "max_tasks_per_step": max_tasks_per_step,
        "allow_self_mixtures": allow_self_mixtures,
        "skip_no_change": skip_no_change,
        "allow_empty_side": allow_empty_side,
        "dedup_delta": dedup_delta,
        "dedup_across_rules": dedup_across_rules,
    }
    expander = CRNExpand(rules=rules, **expander_options)
    network = expander.build(seeds, parallel=parallel, max_workers=max_workers)
    return network