Source code for synkit.Synthesis.Reactor.batch_reactor

import logging
from typing import Any, Dict, Iterable, List, Optional, Union

import networkx as nx
from joblib import Parallel, delayed

from synkit.IO import smiles_to_graph, rsmi_to_its
from synkit.Synthesis.Reactor.mod_reactor import MODReactor
from synkit.Synthesis.Reactor.rule_filter import RuleFilter
from synkit.Synthesis.Reactor.syn_reactor import SynReactor

__all__ = ["BatchReactor"]

# =============================================================================
# Low-level rule application
# =============================================================================


def _apply_rule_raw(
    substrate: nx.Graph,
    rule: nx.Graph,
    invert: bool,
    engine: str,
    *,
    strategy: str,
    explicit_h: bool,
    implicit_temp: bool,
) -> List[str]:
    """
    Apply **one** rule graph to a substrate graph using the specified reactor engine.

    :param substrate: Graph representing the substrate molecule.
    :type substrate: networkx.Graph
    :param rule: Graph representing the reaction rule (ITS template).
    :type rule: networkx.Graph
    :param invert: Whether to apply the rule in reverse.
    :type invert: bool
    :param engine: Which reactor engine to use (`"syn"` or `"mod"`).
    :type engine: str
    :param strategy: Matching strategy passed to SynReactor.
    :type strategy: str
    :param explicit_h: Use explicit hydrogens in SynReactor.
    :type explicit_h: bool
    :param implicit_temp: Use implicit templates in SynReactor.
    :type implicit_temp: bool
    :returns: A list of product SMARTS strings or reaction SMILES.
    :rtype: list of str
    """
    try:
        if engine == "syn":
            reactor = SynReactor(
                substrate=substrate,
                template=rule,
                invert=invert,
                strategy=strategy,
                explicit_h=explicit_h,
                implicit_temp=implicit_temp,
            )
            return list(getattr(reactor, "smarts_list", getattr(reactor, "smarts", [])))

        reactor = MODReactor(
            substrate=substrate,
            rule=rule,
            invert=invert,
            strategy=strategy,
        )
        return [reactor.run().get_reaction_smiles()]

    except Exception as exc:  # pragma: no cover
        logging.getLogger(__name__).debug(
            "%s reactor failed (invert=%s): %s", engine, invert, exc
        )
        return []


# =============================================================================
# Picklable rule-applier with small FIFO cache
# =============================================================================


class _RuleApplier:
    """
    Callable wrapper around `_apply_rule_raw` with a FIFO cache.

    The cache is a per-process dict keyed by (substrate_id, rule_id, invert).
    """

    __slots__ = (
        "_engine",
        "_strategy",
        "_explicit_h",
        "_implicit_temp",
        "_cache",
        "_cache_max",
    )

    def __init__(
        self,
        engine: str,
        *,
        strategy: str,
        explicit_h: bool,
        implicit_temp: bool,
        cache_enabled: bool,
        cache_maxsize: int,
    ) -> None:
        """
        Initialize the rule applier.

        :param engine: Reactor engine, 'syn' or 'mod'.
        :param strategy: Matching strategy for SynReactor.
        :param explicit_h: Pass explicit hydrogens to SynReactor.
        :param implicit_temp: Pass implicit template flag to SynReactor.
        :param cache_enabled: Whether to enable in-process caching.
        :param cache_maxsize: Maximum cache entries before eviction.
        """
        self._engine = engine
        self._strategy = strategy
        self._explicit_h = explicit_h
        self._implicit_temp = implicit_temp
        self._cache = {} if cache_enabled else None  # type: ignore[var-annotated]
        self._cache_max = cache_maxsize

    def _execute(
        self,
        substrate: nx.Graph,
        rule: nx.Graph,
        inv: bool,
    ) -> List[str]:
        """Directly call `_apply_rule_raw` without caching."""
        return _apply_rule_raw(
            substrate,
            rule,
            inv,
            self._engine,
            strategy=self._strategy,
            explicit_h=self._explicit_h,
            implicit_temp=self._implicit_temp,
        )

    def __call__(
        self,
        substrate: nx.Graph,
        rule: nx.Graph,
        inv: bool,
    ) -> List[str]:
        """
        Apply a rule to a substrate, using the cache if enabled.

        :param substrate: Substrate graph.
        :param rule: Rule graph.
        :param inv: Inversion flag.
        :returns: List of reaction outputs.
        """
        if self._cache is None:
            return self._execute(substrate, rule, inv)

        key = (id(substrate), id(rule), inv)
        if key in self._cache:
            return self._cache[key]

        res = self._execute(substrate, rule, inv)
        if len(self._cache) >= self._cache_max:
            self._cache.pop(next(iter(self._cache)))  # FIFO eviction
        self._cache[key] = res
        return res


# =============================================================================
# Helper utilities
# =============================================================================


def _dedupe(items: Iterable[Any]) -> List[Any]:
    """Remove duplicates while preserving order."""
    seen: set[Any] = set()
    out: List[Any] = []
    for x in items:
        if x not in seen:
            seen.add(x)
            out.append(x)
    return out


# =============================================================================
# Public API
# =============================================================================


[docs] class BatchReactor: """ Parallel, cache-aware batch application of reaction rules to SMILES substrates. :param data: List of SMILES strings or dicts containing SMILES under `host_key`. :type data: list of str or dict :param host_key: Key to extract SMILES from dict entries (optional). :type host_key: str or None :param react_engine: Reactor engine: 'syn' or 'mod'. :type react_engine: str :param pre_filter_engine: Pre-filtering engine for rules (None to skip). :type pre_filter_engine: str or None :param explicit_h: Use explicit hydrogens in SynReactor. :type explicit_h: bool :param implicit_temp: Use implicit templates in SynReactor. :type implicit_temp: bool :param strategy: Matching strategy for SynReactor. :type strategy: str :param dedupe: Deduplicate results per-substrate. :type dedupe: bool :param entry_n_jobs: Number of parallel jobs for substrates. :type entry_n_jobs: int :param rule_n_jobs: Number of parallel jobs for rules per substrate. :type rule_n_jobs: int :param parallel_rules: Enable parallelism over rules. :type parallel_rules: bool :param allow_nested: Allow nested parallelism. :type allow_nested: bool :param cache_enabled: Enable in-process per-rule caching. :type cache_enabled: bool :param cache_maxsize: Max entries in per-process cache. :type cache_maxsize: int :param logger: Optional custom logger. :type logger: logging.Logger or None :raises ValueError: If react_engine is invalid or SMILES/rule conversion fails. """ def __init__( self, data: List[Union[str, Dict[str, Any]]], host_key: Optional[str] = None, *, react_engine: str = "syn", pre_filter_engine: Optional[str] = None, explicit_h: bool = True, implicit_temp: bool = False, strategy: str = "bt", dedupe: bool = True, entry_n_jobs: int = 1, rule_n_jobs: int = 1, parallel_rules: bool = False, allow_nested: bool = False, cache_enabled: bool = True, cache_maxsize: int = 32768, logger: Optional[logging.Logger] = None, enable_logging: bool = True, ) -> None: """ Initialize the BatchReactor. See class docstring for parameter details. """ if react_engine.lower() not in {"syn", "mod"}: raise ValueError("react_engine must be 'syn' or 'mod'") self._data = list(data) self._host_key = host_key self._engine = react_engine.lower() self._pre_filter = pre_filter_engine.lower() if pre_filter_engine else None self._explicit_h = explicit_h self._implicit_temp = implicit_temp self._strategy = strategy self._dedupe = dedupe self._entry_jobs = max(1, entry_n_jobs) self._rule_jobs = max(1, rule_n_jobs) self._parallel_rules = parallel_rules self._allow_nested = allow_nested self._log = logger or logging.getLogger(__name__) if enable_logging is False: logging.disable(logging.CRITICAL) self._apply_rule = _RuleApplier( self._engine, strategy=strategy, explicit_h=explicit_h, implicit_temp=implicit_temp, cache_enabled=cache_enabled, cache_maxsize=cache_maxsize, )
[docs] def help(self) -> str: """ Return usage examples and API description. :returns: Multi-line help text. :rtype: str """ return ( "BatchReactor(data, host_key=None, react_engine='syn', pre_filter_engine=None, " "explicit_h=True, implicit_temp=False, strategy='bt', dedupe=True, " "entry_n_jobs=1, rule_n_jobs=1, parallel_rules=False, allow_nested=False, " "cache_enabled=True, cache_maxsize=32768, logger=None)" + "\n" "Use .fit(rules, invert=False) to apply reaction rules to your SMILES batch." )
[docs] def describe(self) -> str: """ Return a configuration summary. :returns: Human-readable settings overview. :rtype: str """ lines = [ "BatchReactor configuration:", f" entries : {len(self._data)}", f" engine : {self._engine}", f" pre_filter : {self._pre_filter}", f" explicit_h : {self._explicit_h}", f" implicit_temp : {self._implicit_temp}", f" strategy : {self._strategy}", f" dedupe : {self._dedupe}", f" entry_n_jobs : {self._entry_jobs}", f" rule_n_jobs : {self._rule_jobs}", f" parallel_rules : {self._parallel_rules}", f" allow_nested : {self._allow_nested}", f" cache_enabled : {self._apply_rule._cache is not None}", f" cache_maxsize : {self._apply_rule._cache_max}", ] return "\n".join(lines)
def __len__(self) -> int: """Return the number of substrates in the batch.""" return len(self._data) def __iter__(self): """Yield each entry from the batch data.""" yield from self._data def __getitem__(self, idx: int) -> Union[str, Dict[str, Any]]: """Get the data entry at index `idx`.""" return self._data[idx] def __repr__(self) -> str: """Concise summary of the BatchReactor instance.""" return ( f"<BatchReactor entries={len(self._data)} engine={self._engine} " f"pre_filter={self._pre_filter} entry_jobs={self._entry_jobs} " f"rule_jobs={self._rule_jobs} parallel_rules={self._parallel_rules}>" )
[docs] def fit( self, rules: Iterable[Any], *, invert: bool = False, ) -> List[Dict[str, Any]]: """ Apply reaction rules to each substrate in the batch. :param rules: Iterable of rule graphs or SMILES strings. :type rules: iterable :param invert: Whether to apply rules in reverse direction. :type invert: bool :returns: A list of dicts, each with: - "out": list of product SMARTS/reaction SMILES - "count": number of outputs :rtype: list of dict """ rule_graphs = self._ensure_graph_rules(rules) direction = "bw" if invert else "fw" def worker(entry: Union[str, Dict[str, Any]]) -> Dict[str, Any]: g = self._to_graph(entry) filtered = ( RuleFilter(g, rule_graphs, invert=invert, engine=self._pre_filter).new_rules # type: ignore[arg-type] if self._pre_filter and self._engine == "syn" else rule_graphs ) out = self._apply_bulk(g, filtered, invert) return {f"{self._engine}_{direction}": out, "count": len(out)} if self._entry_jobs == 1: return [worker(e) for e in self._data] return Parallel( n_jobs=self._entry_jobs, backend="loky", prefer="processes", )(delayed(worker)(e) for e in self._data)
# ------------------------------------------------------------------ # Additional helpers # ------------------------------------------------------------------ def _to_graph(self, entry: Union[str, Dict[str, Any]]) -> nx.Graph: """ Convert a SMILES string or dict entry into a networkx Graph. :param entry: SMILES or dict holding SMILES under host_key. :type entry: str or dict :returns: Molecule graph. :rtype: networkx.Graph :raises KeyError: if host_key is missing for dict. :raises TypeError: if entry is not str or dict. :raises ValueError: if SMILES conversion fails. """ if isinstance(entry, dict): if self._host_key is None: raise KeyError("host_key missing for dict entry") entry = entry[self._host_key] if not isinstance(entry, str): raise TypeError("Each entry must be a SMILES string or dict containing one") g = smiles_to_graph(entry, drop_non_aam=False, use_index_as_atom_map=False) if g is None: raise ValueError(f"Invalid SMILES: {entry}") return g @staticmethod def _ensure_graph_rules(rules: Iterable[Any]) -> List[nx.Graph]: """ Convert an iterable of SMILES or graph rules into graph objects. :param rules: Iterable of rule SMILES or networkx.Graph. :type rules: iterable :returns: List of rule graphs. :rtype: list of networkx.Graph :raises ValueError: if a rule SMILES is invalid. :raises TypeError: for unsupported rule types. """ out: List[nx.Graph] = [] for r in rules: if isinstance(r, str): g = rsmi_to_its(r, core=True) if g is None: raise ValueError(f"Invalid rule SMILES: {r}") out.append(g) elif isinstance(r, nx.Graph): out.append(r) else: raise TypeError("Rules must be str or nx.Graph") return out def _apply_bulk( self, g: nx.Graph, rules: List[nx.Graph], invert: bool ) -> List[str]: """ Apply a list of rule graphs to a substrate graph. :param g: Substrate graph. :type g: networkx.Graph :param rules: List of rule graphs. :type rules: list of networkx.Graph :param invert: Direction flag. :type invert: bool :returns: Flattened list of reaction outputs. :rtype: list of str """ jobs = 1 if ( self._parallel_rules and self._rule_jobs > 1 and (self._allow_nested or self._entry_jobs == 1) ): jobs = self._rule_jobs if jobs == 1: nested = [self._apply_rule(g, r, invert) for r in rules] else: nested = Parallel(n_jobs=jobs, backend="loky", prefer="processes")( delayed(self._apply_rule)(g, r, invert) for r in rules ) flat = [x for sub in nested for x in sub] return _dedupe(flat) if self._dedupe else flat