Source code for synkit.Synthesis.Reactor.benchmark

import logging
from typing import Any, Dict, Iterable, List, Optional
from synkit.Synthesis.Reactor.batch_reactor import BatchReactor

# =============================================================================
# Benchmark subclass
# =============================================================================


class Benchmark(BatchReactor):  # pylint: disable=too-many-arguments
    """
    Extension of BatchReactor to benchmark forward/backward application on
    reaction-SMILES entries.

    Each entry's reaction string ``'reactants>>products'`` is split into a
    reactant field ``'r'`` and a product field ``'p'``. :meth:`fit` then
    applies the rules forward (``invert=False``) with ``'r'`` as host and
    backward (``invert=True``) with ``'p'`` as host.

    :param data: List of dicts containing reaction SMILES under `reaction_key`.
        The dicts are modified in place: ``'r'``/``'p'`` are added on
        construction and ``'fw'``, ``'bw'``, ``'fw_count'``, ``'bw_count'``
        by :meth:`fit`.
    :type data: list of dict
    :param reaction_key: Key for reaction-SMILES strings
        (format 'reactants>>products').
    :type reaction_key: str
    :param react_engine: Reactor engine: 'syn' or 'mod'.
    :type react_engine: str
    :param pre_filter_engine: Pre-filtering engine for rules (None to skip).
    :type pre_filter_engine: str or None
    :param explicit_h: Use explicit hydrogens in SynReactor.
    :type explicit_h: bool
    :param implicit_temp: Use implicit templates in SynReactor.
    :type implicit_temp: bool
    :param strategy: Matching strategy for SynReactor.
    :type strategy: str
    :param dedupe: Deduplicate results per substrate.
    :type dedupe: bool
    :param entry_n_jobs: Parallel jobs for substrates.
    :type entry_n_jobs: int
    :param rule_n_jobs: Parallel jobs for rules per substrate.
    :type rule_n_jobs: int
    :param parallel_rules: Enable rule-level parallelism.
    :type parallel_rules: bool
    :param allow_nested: Allow nested parallelism.
    :type allow_nested: bool
    :param cache_enabled: Enable per-process caching.
    :type cache_enabled: bool
    :param cache_maxsize: Max cache entries before eviction.
    :type cache_maxsize: int
    :param logger: Optional custom logger.
    :type logger: logging.Logger or None
    :param enable_logging: Forwarded unchanged to BatchReactor
        (presumably toggles logging output — see BatchReactor).
    :type enable_logging: bool
    :raises ValueError: If an entry's `reaction_key` value is missing, is not
        a string, or does not contain '>>'.
    """

    def __init__(
        self,
        data: List[Dict[str, Any]],
        reaction_key: str = "reactions",
        *,
        react_engine: str = "syn",
        pre_filter_engine: Optional[str] = None,
        explicit_h: bool = True,
        implicit_temp: bool = False,
        strategy: str = "bt",
        dedupe: bool = True,
        entry_n_jobs: int = 1,
        rule_n_jobs: int = 1,
        parallel_rules: bool = False,
        allow_nested: bool = False,
        cache_enabled: bool = True,
        cache_maxsize: int = 32768,
        logger: Optional[logging.Logger] = None,
        enable_logging: bool = True,
    ) -> None:
        """
        Initialize Benchmark with reaction entries.

        Splits each reaction-SMILES into reactant 'r' and product 'p' (in
        place), then forwards every other parameter to BatchReactor with
        ``host_key='r'``.
        """
        data_prepped = self._get_host(data, reaction_key)
        super().__init__(
            data_prepped,
            host_key="r",
            react_engine=react_engine,
            pre_filter_engine=pre_filter_engine,
            explicit_h=explicit_h,
            implicit_temp=implicit_temp,
            strategy=strategy,
            dedupe=dedupe,
            entry_n_jobs=entry_n_jobs,
            rule_n_jobs=rule_n_jobs,
            parallel_rules=parallel_rules,
            allow_nested=allow_nested,
            cache_enabled=cache_enabled,
            cache_maxsize=cache_maxsize,
            logger=logger,
            enable_logging=enable_logging,
        )

    @staticmethod
    def _get_host(
        data: List[Dict[str, Any]], reaction_key: str = "reactions"
    ) -> List[Dict[str, Any]]:
        """
        Populate 'r' and 'p' SMILES fields for each dict entry, in place.

        The reaction string is split on the first '>>' only.

        :param data: List of dict entries.
        :type data: list of dict
        :param reaction_key: Key for reaction-SMILES string.
        :type reaction_key: str
        :returns: The same list object, with 'r' and 'p' keys added to each
            entry.
        :rtype: list of dict
        :raises ValueError: If an entry's value is missing, not a string, or
            lacks '>>'.
        """
        for entry in data:
            rxn = entry.get(reaction_key)
            if not isinstance(rxn, str) or ">>" not in rxn:
                raise ValueError(
                    f"Invalid reaction string for key '{reaction_key}': {rxn!r}"
                )
            r, p = rxn.split(">>", 1)
            entry["r"], entry["p"] = r, p
        return data

    def fit(
        self,
        rules: Iterable[Any],
    ) -> List[Dict[str, Any]]:
        """
        Perform forward (invert=False) on 'r' and backward (invert=True) on 'p'.

        :param rules: Iterable of rule graphs or SMILES. It is materialized
            into a list once because it is consumed by two separate passes.
        :type rules: iterable
        :returns: The prepared data list, each dict updated with keys
            'fw', 'bw', 'fw_count', 'bw_count'.
        :rtype: list of dict
        """
        # Materialize once: a one-shot iterator (e.g. a generator) would be
        # exhausted by the forward pass and leave the backward pass empty.
        rule_list = list(rules)

        fw_out = super().fit(rule_list, invert=False)

        # Temporarily switch the host field to the products for the backward
        # pass; restore it even if that pass raises so the instance stays
        # consistent for later calls.
        self._host_key = "p"
        try:
            bw_out = super().fit(rule_list, invert=True)
        finally:
            self._host_key = "r"

        for entry, fw, bw in zip(self._data, fw_out, bw_out):
            # Prefer the engine-specific key, then a generic 'out', else the
            # raw per-entry result.
            entry["fw"] = fw.get(f"{self._engine}_fw", fw.get("out", fw))
            entry["fw_count"] = fw.get("count", len(entry["fw"]))
            entry["bw"] = bw.get(f"{self._engine}_bw", bw.get("out", bw))
            entry["bw_count"] = bw.get("count", len(entry["bw"]))
        return self._data

    def describe(self) -> str:
        """
        Return the BatchReactor configuration summary followed by the
        Benchmark host-key note, one item per line.

        :returns: Multi-line summary.
        :rtype: str
        """
        lines = super().describe().splitlines()
        lines.append("Benchmark host_key: 'r' (product host: 'p')")
        # Join with newlines so the summary actually stays multi-line.
        return "\n".join(lines)