# Source code for synkit.Synthesis.Reactor.post_syn

import logging
from typing import Iterable, List, Dict, Any, Callable, Optional

from joblib import Parallel, delayed
import joblib

# optional progress bar
try:
    from tqdm.auto import tqdm
except ImportError:  # tqdm not installed
    tqdm = None

from synkit.Chem.Reaction.standardize import Standardize
from synkit.Chem.utils import clean_radical_rsmi


class _TqdmJoblib:
    def __init__(self, tqdm_obj):
        self.tqdm_obj = tqdm_obj
        self._old_batch_callback = None

    def __enter__(self):
        self._old_batch_callback = joblib.parallel.BatchCompletionCallBack

        class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):  # type: ignore[valid-type]
            def __call__(self_inner, *args, **kwargs):
                try:
                    self.tqdm_obj.update(n=self_inner.batch_size)
                except Exception:
                    pass
                return super(TqdmBatchCompletionCallback, self_inner).__call__(
                    *args, **kwargs
                )

        joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback  # type: ignore[attr-defined]
        return self.tqdm_obj

    def __exit__(self, exc_type, exc_val, exc_tb):
        joblib.parallel.BatchCompletionCallBack = self._old_batch_callback
        self.tqdm_obj.close()


class PostSyn:
    """
    Post-processing helper for reaction records.

    Every record is a dict. Its reaction SMILES is passed through the
    standardizer and stored under ``"std_rxn"``; its forward/backward lists
    are cleaned (mapping removal, optional radical cleanup, order-preserving
    deduplication) and written back under their original keys. Work can be
    run serially or through joblib, with an optional progress bar, and
    entries lacking a reactant or product side can be dropped afterwards.
    The keys used for the reaction, forward and backward entries are
    configurable.
    """

    def __init__(
        self,
        n_jobs: int = 1,
        verbose: int = 2,
        standardizer: Optional[Standardize] = None,
        reaction_key: str = "reactions",
        fw_key: str = "fw",
        bw_key: str = "bw",
    ) -> None:
        """
        :param n_jobs: number of parallel jobs; exactly 1 selects the serial path.
        :param verbose: verbosity level (0=errors only, 1=warnings, 2=info,
            3=debug); also forwarded to joblib's ``Parallel``.
        :param standardizer: optional Standardize instance; a fresh one is
            created when omitted.
        :param reaction_key: key in each input dict holding the reaction SMILES.
        :param fw_key: key holding the forward AAM/reaction SMILES list.
        :param bw_key: key holding the backward AAM/reaction SMILES list.
        """
        self.n_jobs = n_jobs
        self.verbose = verbose
        self.std = Standardize() if standardizer is None else standardizer
        self.reaction_key = reaction_key
        self.fw_key = fw_key
        self.bw_key = bw_key
        self.logger = logging.getLogger(self.__class__.__name__)
        self._configure_logger()

    def _configure_logger(self):
        """Attach a stream handler if none is reachable and set the level from ``verbose``."""
        log = self.logger
        if not log.hasHandlers():
            stream = logging.StreamHandler()
            stream.setFormatter(
                logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
            )
            log.addHandler(stream)

        if self.verbose >= 3:
            level = logging.DEBUG
        else:
            # Anything other than exactly 2 or 1 falls back to errors only.
            level = {2: logging.INFO, 1: logging.WARNING}.get(
                self.verbose, logging.ERROR
            )
        log.setLevel(level)

    def clean_aam(
        self, list_aam: Iterable[str], remove_radical: bool = True
    ) -> List[str]:
        """
        Strip atom-atom mappings from each entry and return the unique results.

        Each entry goes through the standardizer with ``remove_aam=True`` and,
        when requested, through ``clean_radical_rsmi``. Empty results and
        repeats are skipped; first-seen order is kept. An entry that raises is
        logged at debug level and skipped.

        :param list_aam: iterable of mapped reaction SMILES.
        :param remove_radical: if True, apply ``clean_radical_rsmi`` as well.
        :return: list of cleaned, de-duplicated strings.
        """
        # A dict doubles as an insertion-ordered set.
        ordered: Dict[str, None] = {}
        for smi in list_aam:
            try:
                cleaned = self.std.fit(smi, remove_aam=True)
                if remove_radical:
                    cleaned = clean_radical_rsmi(cleaned)
                if cleaned and cleaned not in ordered:
                    ordered[cleaned] = None
            except Exception as exc:
                self.logger.debug(
                    "Failed to clean AAM for %r: %s",
                    smi,
                    exc,
                    exc_info=self.verbose >= 3,
                )
        return list(ordered)

    def _rxn_has_both_sides(self, rxn: Any) -> bool:
        """
        Tell whether ``rxn`` is a string with non-blank reactant and product sides.

        Accepts ``left>>right`` (split on the first ``>>``) and the
        three-field ``reactants>agents>products`` form; anything else is
        rejected.
        """
        if not isinstance(rxn, str):
            return False

        lhs, arrow, rhs = rxn.partition(">>")
        if arrow:
            return bool(lhs.strip() and rhs.strip())

        fields = rxn.split(">")
        if len(fields) != 3:
            return False
        return bool(fields[0].strip() and fields[2].strip())

    def _process_single(self, value: Dict[str, Any]) -> Dict[str, Any]:
        """
        Worker: standardize one record's reaction and clean its fw/bw lists.

        Returns a shallow copy of ``value`` with ``"std_rxn"`` added (None when
        the reaction is missing or standardization raises) and the fw/bw
        entries replaced by their cleaned lists (empty list on failure).
        """
        result = dict(value)  # shallow copy; the input record is not modified

        # Standardize the reaction; std_rxn stays None on any failure.
        std_rxn = None
        raw_rxn = value.get(self.reaction_key)
        if raw_rxn is None:
            self.logger.warning(
                "Missing reaction key '%s' in value: %r", self.reaction_key, value
            )
        else:
            try:
                std_rxn = self.std.fit(raw_rxn)
            except Exception as exc:
                self.logger.warning("Standardization failed for %r: %s", raw_rxn, exc)
        result["std_rxn"] = std_rxn

        # Clean the forward list, then the backward list, with the same recipe.
        sides = (
            (self.fw_key, "Failed cleaning fw (%s) for %r: %s"),
            (self.bw_key, "Failed cleaning bw (%s) for %r: %s"),
        )
        for key, failure_msg in sides:
            try:
                result[key] = self.clean_aam(value.get(key, []))
            except Exception as exc:
                self.logger.debug(
                    failure_msg,
                    key,
                    value,
                    exc,
                    exc_info=self.verbose >= 3,
                )
                result[key] = []

        return result

    def _run_serial(
        self, seq: List[Dict[str, Any]], progress: bool
    ) -> List[Dict[str, Any]]:
        """Process ``seq`` in this process, wrapping it in tqdm when requested and available."""
        items = seq
        if progress:
            if tqdm:
                items = tqdm(seq, desc="PostSyn", unit="item")
            else:
                self.logger.info("tqdm not available; running without progress bar.")
        return [self._process_single(record) for record in items]

    def _run_parallel(
        self, seq: List[Dict[str, Any]], progress: bool
    ) -> List[Dict[str, Any]]:
        """
        Process ``seq`` through joblib's loky backend.

        With ``progress`` and tqdm available, the run is wrapped in
        ``_TqdmJoblib``; if that wrapped run raises, a warning is logged and
        the whole batch is run again without the progress wrapper.
        """

        def launch() -> List[Dict[str, Any]]:
            pool = Parallel(n_jobs=self.n_jobs, backend="loky", verbose=self.verbose)
            return pool(delayed(self._process_single)(record) for record in seq)

        if progress:
            if tqdm:
                try:
                    bar = tqdm(total=len(seq), desc="PostSyn", unit="item")
                    with _TqdmJoblib(bar):
                        return launch()
                except Exception as e:
                    self.logger.warning(
                        "Progress-wrapped parallel execution failed (%s); falling back to normal Parallel.",
                        e,
                    )
            else:
                self.logger.info(
                    "tqdm not installed; running parallel without progress bar."
                )

        return launch()

    def _filter_fw_bw(self, results: List[Dict[str, Any]]) -> None:
        """Drop, in place, fw/bw entries that lack a reactant or product side."""
        is_complete = self._rxn_has_both_sides
        debug_enabled = self.verbose >= 3

        for record in results:
            fw_entries = record.get(self.fw_key, [])
            bw_entries = record.get(self.bw_key, [])
            kept_fw = list(filter(is_complete, fw_entries))
            kept_bw = list(filter(is_complete, bw_entries))

            if debug_enabled:
                n_fw_dropped = len(fw_entries) - len(kept_fw)
                n_bw_dropped = len(bw_entries) - len(kept_bw)
                if n_fw_dropped:
                    self.logger.debug(
                        "Dropped %d incomplete fw entries for record %r",
                        n_fw_dropped,
                        fw_entries,
                    )
                if n_bw_dropped:
                    self.logger.debug(
                        "Dropped %d incomplete bw entries for record %r",
                        n_bw_dropped,
                        bw_entries,
                    )

            record[self.fw_key] = kept_fw
            record[self.bw_key] = kept_bw

    def process(
        self,
        data: Iterable[Dict[str, Any]],
        *,
        progress: bool = False,
        prefilter: Optional[Callable[[Dict[str, Any]], bool]] = None,
        filter_incomplete_rxn: bool = True,
    ) -> List[Dict[str, Any]]:
        """
        Process reaction entries.

        :param data: iterable of dicts; it is materialized into a list first.
        :param progress: show a progress bar if True (needs tqdm).
        :param prefilter: predicate; only entries for which it is truthy are
            processed.
        :param filter_incomplete_rxn: if True, drop incomplete SMILES inside
            the fw/bw lists after processing.
        :return: processed list with the standardized reaction under
            ``"std_rxn"`` and cleaned fw/bw under their original keys.
        """
        seq = list(data)
        if prefilter:
            seq = list(filter(prefilter, seq))

        runner = self._run_serial if self.n_jobs == 1 else self._run_parallel
        results = runner(seq, progress)

        if filter_incomplete_rxn:
            self._filter_fw_bw(results)

        return results