Source code for synkit.Graph.ITS.its_relabel

from __future__ import annotations
from typing import Any, Dict, Iterable, List, Tuple
from rdkit import Chem

from synkit.Graph.syn_graph import SynGraph
from synkit.IO import setup_logging
from synkit.IO.graph_to_mol import GraphToMol
from synkit.IO.chem_converter import smiles_to_graph
from synkit.Graph.Matcher.graph_morphism import find_graph_isomorphism

logger = setup_logging("INFO")


[docs] class ITSRelabel: """Extend reaction SMILES through atom-map alignment between reactant and product SynGraphs. :cvar logger: Logger instance for debug and info messages. :type logger: logging.Logger :ivar graph_to_mol: Converter from SynGraph to RDKit Mol. :type graph_to_mol: GraphToMol """ def __init__(self) -> None: """Initialize ITSRelabel with default GraphToMol converter.""" self.graph_to_mol = GraphToMol() @staticmethod def _get_nodes_with_atom_map(graph: SynGraph) -> List[Any]: """Extract node IDs with a non-zero atom_map attribute from a SynGraph. :param graph: Input SynGraph with 'atom_map' on nodes. :type graph: SynGraph :returns: List of node identifiers where 'atom_map' != 0. :rtype: List[Any] """ return [ n for n, attrs in graph.raw.nodes(data=True) if attrs.get("atom_map", 0) != 0 ] @staticmethod def _remove_internal_edges(graph: SynGraph, nodes: List[Any]) -> SynGraph: """Remove edges connecting nodes in the given list from a SynGraph. :param graph: Input SynGraph to prune. :type graph: SynGraph :param nodes: Node IDs whose connecting edges will be removed. :type nodes: List[Any] :returns: A new SynGraph with specified edges removed. :rtype: SynGraph """ G_copy = SynGraph(graph.raw.copy()) node_set = set(nodes) edges_to_remove = [ (u, v) for u, v in G_copy.raw.edges() if u in node_set and v in node_set ] G_copy.raw.remove_edges_from(edges_to_remove) return G_copy @staticmethod def _dict_to_tuple_list( mapping: Dict[Any, Any], sort_by_key: bool = False, sort_by_value: bool = False ) -> List[Tuple[Any, Any]]: """Convert a mapping dict into a sorted list of tuples. :param mapping: Dictionary to convert. :type mapping: Dict[Any, Any] :param sort_by_key: Sort tuples by key if True. :type sort_by_key: bool :param sort_by_value: Sort tuples by value if True. :type sort_by_value: bool :returns: List of (key, value) tuples. :rtype: List[Tuple[Any, Any]] """ items = list(mapping.items()) if sort_by_key: items.sort(key=lambda x: x[0]) elif sort_by_value: items.sort(key=lambda x: x[1]) return items def _update_mapping( self, G: SynGraph, H: SynGraph, mapping: Iterable[Tuple[Any, Any]], aam_key: str = "atom_map", ) -> Tuple[SynGraph, SynGraph]: """Update node attributes in two SynGraphs based on a sequential mapping. This method resets the specified atom-map attribute for all nodes in both graphs to 0, then assigns a new atom-map value (i+1) for each mapped pair: G.nodes[g_node][aam_key] = i + 1 H.nodes[h_node][aam_key] = i + 1 :param G: First SynGraph to update (reactant). :type G: SynGraph :param H: Second SynGraph to update (product). :type H: SynGraph :param mapping: Iterable of (g_node, h_node) tuples defining node correspondence. :type mapping: Iterable[Tuple[Any, Any]] :param aam_key: Name of the atom-map attribute on each node. :type aam_key: str :returns: Tuple of updated SynGraphs (G_updated, H_updated). :rtype: Tuple[SynGraph, SynGraph] """ # Create deep copies of raw graphs G_copy = SynGraph(G.raw.copy()) H_copy = SynGraph(H.raw.copy()) # Reset atom-map to zero for all nodes for node in G_copy.raw.nodes(): G_copy.raw.nodes[node][aam_key] = 0 for node in H_copy.raw.nodes(): H_copy.raw.nodes[node][aam_key] = 0 # Assign new sequential mapping values for i, (g_node, h_node) in enumerate(mapping): value = i + 1 if g_node in G_copy.raw: G_copy.raw.nodes[g_node][aam_key] = value else: logger.warning(f"Node {g_node} not found in reactant graph") if h_node in H_copy.raw: H_copy.raw.nodes[h_node][aam_key] = value else: logger.warning(f"Node {h_node} not found in product graph") return G_copy, H_copy
[docs] def fit(self, rsmi: str) -> str: """Generate an extended reaction SMILES by aligning atom maps of reactant and product. :param rsmi: Reaction SMILES string formatted as 'reactant>>product'. :type rsmi: str :returns: Extended reaction SMILES after remapping. :rtype: str :raises ValueError: If input format is invalid or graphs are not isomorphic. :example: >>> its = ITSRelabel() >>> its.fit('CCO:1>>CC=O:1') 'CCO>>CC=O' """ # Parse reaction SMILES try: react_smiles, prod_smiles = rsmi.split(">>") except ValueError as e: raise ValueError("Expected 'reactant>>product' format") from e # Convert to SynGraphs G = SynGraph( smiles_to_graph( react_smiles, drop_non_aam=False, use_index_as_atom_map=False ) ) H = SynGraph( smiles_to_graph( prod_smiles, drop_non_aam=False, use_index_as_atom_map=False ) ) # Identify and prune mapped nodes R_nodes = self._get_nodes_with_atom_map(G) P_nodes = self._get_nodes_with_atom_map(H) Gp = self._remove_internal_edges(G, R_nodes) Hp = self._remove_internal_edges(H, P_nodes) # Find isomorphism mapping mapping = find_graph_isomorphism(Gp.raw, Hp.raw, use_defaults=True) if not mapping: raise ValueError("No isomorphism found between pruned graphs") # Update atom_map attributes mapped_pairs = self._dict_to_tuple_list(mapping, sort_by_key=True) G_new, H_new = self._update_mapping(G, H, mapped_pairs, aam_key="atom_map") # Convert back to molecules and SMILES mol_G = self.graph_to_mol.graph_to_mol(G_new.raw, use_h_count=True) mol_H = self.graph_to_mol.graph_to_mol(H_new.raw, use_h_count=True) return f"{Chem.MolToSmiles(mol_G)}>>{Chem.MolToSmiles(mol_H)}"