# Source code for synkit.Rule.Apply.retro_reactor

import heapq
import importlib.util
from typing import Dict, List, Tuple


from synkit.IO.chem_converter import gml_to_smart
from synkit.Rule.Modify.molecule_rule import MoleculeRule
from synkit.Chem.utils import (
    get_sanitized_smiles,
    remove_duplicates,
    filter_smiles,
    count_carbons,
)

if importlib.util.find_spec("mod"):
    from mod import ruleGMLString, RCMatch
else:
    ruleGMLString = None
    RCMatch = None
    print("Optional 'mod' package not found")


class RetroReactor:
    """Apply reaction rules in backward mode to propose precursor SMILES.

    Backward applications are memoized per ``(smiles, rule)`` pair, and a
    carbon-count heuristic is provided for A* search.
    """

    def __init__(self) -> None:
        """Initialize the RetroReactor class with an empty cache.

        Attributes:
        - backward_cache: A dictionary cache (keyed by (smiles, rule)) to
          avoid redundant computations.
        """
        self.backward_cache: Dict[Tuple[str, str], List[str]] = {}

    def _apply_backward(self, smiles: str, rule: str) -> List[str]:
        """Apply a transformation rule in backward mode to a SMILES string,
        returning possible precursors.

        Results are cached per ``(smiles, rule)`` pair. Every call returns a
        fresh list, so callers may modify the result without corrupting the
        cache.

        Parameters:
        - smiles (str): SMILES string to transform.
        - rule (str): Transformation rule, as accepted by ``ruleGMLString``.

        Returns:
        - List[str]: Possible precursor SMILES strings, in the order they were
          first produced by the rule matcher (after filtering).

        Raises:
        - ImportError: If the optional ``mod`` package is not installed and
          the result is not already cached.
        """
        cache_key = (smiles, rule)
        if cache_key in self.backward_cache:
            # Hand out a copy: returning the stored list would let callers
            # mutate the cached entry in place.
            return list(self.backward_cache[cache_key])

        # The module binds these names to None when 'mod' is unavailable;
        # fail with a clear message instead of "'NoneType' is not callable".
        if ruleGMLString is None or RCMatch is None:
            raise ImportError(
                "The optional 'mod' package is required to apply rules in "
                "backward mode."
            )

        # Convert rule to GML in backward mode (inverted)
        rule_str = ruleGMLString(rule, invert=True, add=False)
        mol_rule = MoleculeRule().generate_molecule_rule(smiles)
        mol_rule_str = ruleGMLString(mol_rule, add=False)
        matcher = RCMatch(mol_rule_str, rule_str)

        # An insertion-ordered dict de-duplicates like a set, but the order
        # then depends only on the matcher's output order, not on per-process
        # string hash randomization.
        precursors: Dict[str, None] = {}
        for match_rule in matcher.composeAll():
            # In user-defined backward mode, "reactants"
            # appear in smarts.split(">>")[1].
            smarts = gml_to_smart(match_rule.getGMLString(), sanitize=False)
            reactants = smarts.split(">>")[1].split(".")
            for reactant in get_sanitized_smiles(reactants):
                precursors[reactant] = None

        # Filter out SMILES that are invalid relative to the original
        results_list = filter_smiles(list(precursors), smiles)
        results_list = remove_duplicates(results_list)
        self.backward_cache[cache_key] = list(results_list)
        return list(self.backward_cache[cache_key])

    def _heuristic(self, current_smiles: str, precursor_smiles: str) -> int:
        """Heuristic function for A* search.

        Here, we define the "distance" as the absolute difference in the
        carbon count between the current SMILES and the known precursor
        SMILES.

        Parameters:
        - current_smiles (str): The SMILES of the node being expanded.
        - precursor_smiles (str): The SMILES of the known precursor (our
          target).

        Returns:
        - int: Estimated cost (distance) based on difference in carbon count.
        """
        return abs(count_carbons(current_smiles) - count_carbons(precursor_smiles))