Source code for synkit.Graph.Hyrogen.hcomplete

import itertools
import networkx as nx
from copy import deepcopy, copy
from joblib import Parallel, delayed
from typing import Dict, List, Tuple, Iterable, Optional

from synkit.IO.debug import setup_logging
from synkit.Graph.Feature.wl_hash import WLHash
from synkit.Graph.ITS.its_construction import ITSConstruction
from synkit.Graph.ITS.its_decompose import get_rc, its_decompose
from synkit.Graph.Hyrogen._misc import (
    check_hcount_change,
    check_explicit_hydrogen,
    get_priority,
    check_equivariant_graph,
)

logger = setup_logging()


class HComplete:
    """A class for inferring hydrogens to complete a reaction center (RC) or
    ITS graph."""

    @staticmethod
    def process_single_graph_data(
        graph_data: Dict[str, nx.Graph],
        its_key: str = "ITS",
        rc_key: str = "RC",
        ignore_aromaticity: bool = False,
        balance_its: bool = True,
        get_priority_graph: bool = False,
        max_hydrogen: int = 7,
    ) -> Dict[str, Optional[nx.Graph]]:
        """Processes a single graph data dictionary, completing hydrogens in
        its ITS graph when the reactant and product hydrogen counts differ.

        The input dictionary is shallow-copied; the copy is updated and
        returned.

        Parameters:
        - graph_data (Dict[str, nx.Graph]): Dictionary containing the graph data.
        - its_key (str): Key where the ITS graph is stored.
        - rc_key (str): Key where the RC graph is stored.
        - ignore_aromaticity (bool): Forwarded to the ITS construction.
          Default is False.
        - balance_its (bool): Forwarded to the ITS construction.
          Default is True.
        - get_priority_graph (bool): If True, candidates that do not all agree
          are re-checked on the subset chosen by `get_priority`.
          Default is False.
        - max_hydrogen (int): Largest hydrogen-count change for which the
          inference step is attempted; larger changes yield None entries.

        Returns:
        - Dict[str, Optional[nx.Graph]]: Dictionary with updated ITS and RC
          graph data; both entries are None if processing fails.
        """
        graphs = copy(graph_data)
        its = graphs.get(its_key, None)
        if not isinstance(its, nx.Graph) or its.number_of_nodes() == 0:
            graphs[its_key], graphs[rc_key] = None, None
            return graphs

        react_graph, prod_graph = its_decompose(its)
        hcount_change = check_hcount_change(react_graph, prod_graph)

        # A change of exactly zero leaves the stored ITS/RC untouched.
        if hcount_change != 0:
            if hcount_change <= max_hydrogen:
                graphs = HComplete.process_multiple_hydrogens(
                    graphs,
                    its_key,
                    rc_key,
                    react_graph,
                    prod_graph,
                    ignore_aromaticity,
                    balance_its,
                    get_priority_graph,
                )
            else:
                # Too many hydrogens to enumerate: give up on this entry.
                graphs[its_key], graphs[rc_key] = None, None

        # `.get` avoids a KeyError when the input never carried an RC entry.
        rc = graphs.get(rc_key)
        if rc is not None and (
            not isinstance(rc, nx.Graph) or rc.number_of_nodes() == 0
        ):
            # An unusable RC invalidates the ITS as well.
            graphs[its_key] = None
            graphs[rc_key] = None

        return graphs

    def process_graph_data_parallel(
        self,
        graph_data_list: List[Dict[str, nx.Graph]],
        its_key: str = "ITS",
        rc_key: str = "RC",
        n_jobs: int = 1,
        verbose: int = 0,
        ignore_aromaticity: bool = False,
        balance_its: bool = True,
        get_priority_graph: bool = False,
        max_hydrogen: int = 7,
    ) -> List[Dict[str, Optional[nx.Graph]]]:
        """Applies `process_single_graph_data` to every dictionary in a list
        using joblib.

        Parameters:
        - graph_data_list (List[Dict[str, nx.Graph]]): List of dictionaries
          containing the graph data.
        - its_key (str): Key where the ITS graph is stored.
        - rc_key (str): Key where the RC graph is stored.
        - n_jobs (int): Number of parallel jobs to run.
        - verbose (int): Verbosity level for the parallel process.
        - ignore_aromaticity (bool): Forwarded to the ITS construction.
          Default is False.
        - balance_its (bool): Forwarded to the ITS construction.
          Default is True.
        - get_priority_graph (bool): Forwarded to
          `process_single_graph_data`. Default is False.
        - max_hydrogen (int): Largest hydrogen-count change for which the
          inference step is attempted.

        Returns:
        - List[Dict[str, Optional[nx.Graph]]]: One processed dictionary per
          input dictionary; ITS and RC entries are None where processing
          failed.
        """
        processed_data = Parallel(n_jobs=n_jobs, verbose=verbose)(
            delayed(self.process_single_graph_data)(
                graph_data,
                its_key,
                rc_key,
                ignore_aromaticity,
                balance_its,
                get_priority_graph,
                max_hydrogen,
            )
            for graph_data in graph_data_list
        )
        return processed_data

    @staticmethod
    def _select_consensus(
        its_list: List[nx.Graph],
        rc_list: List[nx.Graph],
        rc_sig: list,
    ) -> Tuple[Optional[nx.Graph], Optional[nx.Graph]]:
        """Returns the first (ITS, RC) pair if all candidates agree, else
        (None, None).

        Candidates agree when they share one signature and
        `check_equivariant_graph` reports `len(rc_list) - 1` for them.
        """
        if len(set(rc_sig)) != 1:
            return None, None
        _, equivariant = check_equivariant_graph(rc_list)
        if equivariant == len(rc_list) - 1:
            return its_list[0], rc_list[0]
        return None, None

    @staticmethod
    def process_multiple_hydrogens(
        graph_data: Dict[str, nx.Graph],
        its_key: str,
        rc_key: str,
        react_graph: nx.Graph,
        prod_graph: nx.Graph,
        ignore_aromaticity: bool,
        balance_its: bool,
        get_priority_graph: bool = False,
    ) -> Dict[str, Optional[nx.Graph]]:
        """Enumerates hydrogen placements between reactant and product graphs
        and keeps an ITS/RC pair only when the candidates agree.

        `graph_data` is modified in place and returned.

        Parameters:
        - graph_data (Dict[str, nx.Graph]): Dictionary containing the graph data.
        - its_key (str): Key for the ITS graph in the dictionary.
        - rc_key (str): Key for the RC graph in the dictionary.
        - react_graph (nx.Graph): Graph representing the reactants.
        - prod_graph (nx.Graph): Graph representing the products.
        - ignore_aromaticity (bool): Forwarded to the ITS construction.
        - balance_its (bool): Forwarded to the ITS construction.
        - get_priority_graph (bool): If True, the candidates selected by
          `get_priority` are checked for agreement as well, and override the
          result when they agree.

        Returns:
        - Dict[str, Optional[nx.Graph]]: The same dictionary, with the ITS and
          RC entries replaced by the agreed graphs, or by None when no
          agreement was reached.
        """
        combinations_solution = HComplete.add_hydrogen_nodes_multiple(
            react_graph,
            prod_graph,
            ignore_aromaticity,
            balance_its,
            get_priority_graph,
        )

        # Keep only candidates whose reaction center is a non-empty graph.
        candidates = [
            (its, rc, sig)
            for _react, _prod, its, rc, sig in combinations_solution
            if isinstance(rc, nx.Graph) and rc.number_of_nodes() > 0
        ]
        if not candidates:
            graph_data[its_key], graph_data[rc_key] = None, None
            return graph_data

        its_list, rc_list, rc_sig = (list(column) for column in zip(*candidates))

        graph_data[its_key], graph_data[rc_key] = HComplete._select_consensus(
            its_list, rc_list, rc_sig
        )

        if get_priority_graph:
            priority_indices = list(get_priority(rc_list))
            its_selected, rc_selected = HComplete._select_consensus(
                [its_list[i] for i in priority_indices],
                [rc_list[i] for i in priority_indices],
                [rc_sig[i] for i in priority_indices],
            )
            # Only override when the prioritized subset agrees; otherwise the
            # result of the full-set check above is kept.
            if rc_selected is not None:
                graph_data[its_key] = its_selected
                graph_data[rc_key] = rc_selected

        return graph_data

    @staticmethod
    def add_hydrogen_nodes_multiple(
        react_graph: nx.Graph,
        prod_graph: nx.Graph,
        ignore_aromaticity: bool,
        balance_its: bool,
        get_priority_graph: bool = False,
    ) -> List[tuple]:
        """Generates candidate reactant/product graphs with explicit hydrogen
        nodes, one candidate per permutation of the hydrogen node IDs assigned
        on the product side.

        Parameters:
        - react_graph (nx.Graph): The reactant graph.
        - prod_graph (nx.Graph): The product graph.
        - ignore_aromaticity (bool): Forwarded to the ITS construction.
        - balance_its (bool): Forwarded to the ITS construction.
        - get_priority_graph (bool): If False, enumeration stops and an empty
          list is returned as soon as two consecutive candidates have
          different reaction-center signatures.

        Returns:
        - List[tuple]: One `(react_graph, prod_graph, its, rc, sig)` tuple per
          permutation, where `sig` is the Weisfeiler-Lehman hash of `rc`.
          Empty if candidates disagreed and `get_priority_graph` is False.
        """
        react_graph_copy = react_graph.copy()
        prod_graph_copy = prod_graph.copy()
        react_explicit_h, hydrogen_nodes = check_explicit_hydrogen(react_graph_copy)
        prod_explicit_h, _ = check_explicit_hydrogen(prod_graph_copy)

        hydrogen_nodes_form, hydrogen_nodes_break = [], []
        primary_graph = (
            react_graph_copy if react_explicit_h <= prod_explicit_h else prod_graph_copy
        )
        for node_id in primary_graph.nodes:
            try:
                # Difference in hydrogen counts between reactant and product.
                hcount_diff = react_graph_copy.nodes[node_id].get(
                    "hcount", 0
                ) - prod_graph_copy.nodes[node_id].get("hcount", 0)
            except KeyError:
                # The node does not exist in the other graph.
                continue
            if hcount_diff > 0:
                # Reactant carries more hydrogens here: one entry per H lost.
                hydrogen_nodes_break.extend([node_id] * hcount_diff)
            elif hcount_diff < 0:
                # Product carries more hydrogens here: one entry per H gained.
                hydrogen_nodes_form.extend([node_id] * -hcount_diff)

        max_index = max(
            max(react_graph_copy.nodes, default=0),
            max(prod_graph_copy.nodes, default=0),
        )
        # Fresh node IDs, allocated past the largest ID used in either graph.
        range_implicit_h = range(
            max_index + 1,
            max_index + 1 + len(hydrogen_nodes_form) - react_explicit_h,
        )
        combined_indices = list(range_implicit_h) + hydrogen_nodes

        updated_graphs = []
        # itertools.permutations yields the input order first, so the
        # reactant side always uses `combined_indices` as-is; the
        # permutations are consumed lazily to allow an early exit.
        for permutation in itertools.permutations(combined_indices):
            # Reactant side: fixed ID assignment, atom maps kept from the
            # parent atom.
            current_react_graph = HComplete.add_hydrogen_nodes_multiple_utils(
                react_graph_copy,
                zip(hydrogen_nodes_break, combined_indices),
                atom_map_update=False,
            )
            # Product side: ID assignment varies with the permutation.
            current_prod_graph = HComplete.add_hydrogen_nodes_multiple_utils(
                prod_graph_copy, zip(hydrogen_nodes_form, permutation)
            )
            its = ITSConstruction().ITSGraph(
                current_react_graph,
                current_prod_graph,
                ignore_aromaticity=ignore_aromaticity,
                balance_its=balance_its,
            )
            rc = get_rc(its)
            sig = WLHash(iterations=3).weisfeiler_lehman_graph_hash(rc)
            if not get_priority_graph and updated_graphs:
                # Without prioritization any disagreement is fatal, so stop
                # at the first signature that differs from the previous one.
                if sig != updated_graphs[-1][-1]:
                    return []
            updated_graphs.append(
                (current_react_graph, current_prod_graph, its, rc, sig)
            )
        return updated_graphs

    @staticmethod
    def add_hydrogen_nodes_multiple_utils(
        graph: nx.Graph,
        node_id_pairs: Iterable[Tuple[int, int]],
        atom_map_update: bool = True,
    ) -> nx.Graph:
        """Creates and returns a deep copy of the graph with hydrogen nodes
        added for the given node ID pairs.

        Parameters:
        - graph (nx.Graph): The base graph; it is not modified.
        - node_id_pairs (Iterable[Tuple[int, int]]): Pairs of node IDs
          (original node, new hydrogen node) to link with a single bond.
        - atom_map_update (bool): If True, the hydrogen's 'atom_map' is its own
          node ID; otherwise it takes the original node's 'atom_map'
          (0 if absent).

        Returns:
        - nx.Graph: A new graph instance with the added hydrogen nodes; each
          original node's 'hcount' is decremented once per hydrogen attached.
        """
        new_graph = deepcopy(graph)
        for node_id, new_hydrogen_node_id in node_id_pairs:
            if atom_map_update:
                atom_map_val = new_hydrogen_node_id
            else:
                atom_map_val = new_graph.nodes[node_id].get("atom_map", 0)
            new_graph.add_node(
                new_hydrogen_node_id,
                charge=0,
                hcount=0,
                aromatic=False,
                element="H",
                atom_map=atom_map_val,
            )
            new_graph.add_edge(
                node_id,
                new_hydrogen_node_id,
                order=1.0,
                bond_type="SINGLE",
            )
            # The hydrogen is now explicit, so it no longer counts as implicit.
            new_graph.nodes[node_id]["hcount"] -= 1
        return new_graph