Source code for synkit.Graph.Hyrogen.hextend

import networkx as nx
from joblib import Parallel, delayed
from typing import List, Tuple, Dict

from synkit.Graph.Matcher.graph_cluster import GraphCluster
from synkit.Graph.Hyrogen.hcomplete import HComplete

from synkit.Graph.Feature.wl_hash import WLHash
from synkit.Graph.Hyrogen._misc import check_hcount_change
from synkit.Graph.ITS.its_decompose import get_rc, its_decompose

# Module-level GraphCluster instance, created once at import time.
# HExtend._process calls its ``iterative_cluster`` method on the candidate
# reaction-center graphs together with their signatures.
cluster: GraphCluster = GraphCluster()


class HExtend(HComplete):
    """Expand ITS graphs with explicit hydrogen nodes and keep one
    representative per cluster of equivalent reaction centers.

    All methods are static. When reactant and product differ in hydrogen
    count, ``HComplete.add_hydrogen_nodes_multiple`` is used to generate
    candidate ITS / reaction-center (RC) graphs. The candidates are then
    clustered, and one graph per cluster is retained.
    """

    @staticmethod
    def get_unique_graphs_for_clusters(
        graphs: List[nx.Graph], cluster_indices: List[set]
    ) -> List[nx.Graph]:
        """Retrieve one representative graph for each cluster.

        Clusters are sets of indices into ``graphs``. Empty clusters are
        skipped.

        Parameters:
        - graphs (List[nx.Graph]): List of networkx graphs.
        - cluster_indices (List[set]): List of sets. Each set contains the
          indices of the graphs that belong to the same cluster.

        Returns:
        - List[nx.Graph]: One graph per non-empty cluster. The representative
          is the first element produced by iterating the set. Sets are
          unordered, so which member is chosen is arbitrary.

        Raises:
        - TypeError: If any element of ``cluster_indices`` is not a set.
        - ValueError: If any index is outside ``range(len(graphs))``.
        """
        if not all(isinstance(cluster, set) for cluster in cluster_indices):
            raise TypeError("Each cluster index must be a set of integers.")
        if any(
            min(cluster) < 0 or max(cluster) >= len(graphs)
            for cluster in cluster_indices
            if cluster
        ):
            raise ValueError("Cluster indices are out of the range of the graphs list.")
        # An unmodified set iterates in the same order each time. Calling this
        # with the same cluster sets on two parallel lists therefore picks
        # matching positions from both.
        return [graphs[next(iter(cluster))] for cluster in cluster_indices if cluster]

    @staticmethod
    def _extend(
        its: nx.Graph,
        ignore_aromaticity: bool,
        balance_its: bool,
    ) -> Tuple[List[nx.Graph], List[nx.Graph], List[str]]:
        """Build candidate ITS graphs (with hydrogen nodes when needed),
        their reaction centers, and the reaction-center signatures.

        Parameters:
        - its (nx.Graph): The initial ITS graph to process.
        - ignore_aromaticity (bool): Forwarded to
          ``HComplete.add_hydrogen_nodes_multiple``.
        - balance_its (bool): Forwarded to
          ``HComplete.add_hydrogen_nodes_multiple``.

        Returns:
        - Tuple[List[nx.Graph], List[nx.Graph], List[str]]: A tuple
          ``(rc_list, its_list, signatures)`` of equal-length, index-aligned
          lists.
        """
        react_graph, prod_graph = its_decompose(its)
        hcount_change = check_hcount_change(react_graph, prod_graph)
        if hcount_change == 0:
            # No hydrogen-count change: keep the input ITS as the only
            # candidate and compute a WL hash of its reaction center.
            its_list = [its]
            rc_list = [get_rc(its)]
            sigs = [
                WLHash(iterations=3).weisfeiler_lehman_graph_hash(rc) for rc in rc_list
            ]
            return rc_list, its_list, sigs

        combinations_solution = HComplete.add_hydrogen_nodes_multiple(
            react_graph,
            prod_graph,
            ignore_aromaticity,
            balance_its,
            get_priority_graph=True,
        )

        rc_list, its_list, rc_sig = [], [], []
        # Each solution unpacks to 5 fields; only the last three
        # (ITS, RC, signature) are used. Candidates whose RC is not a
        # non-empty nx.Graph are dropped. The loop variable is named
        # ``candidate_its`` so the ``its`` parameter is not shadowed.
        for _, _, candidate_its, rc, sig in combinations_solution:
            if rc and isinstance(rc, nx.Graph) and rc.number_of_nodes() > 0:
                rc_list.append(rc)
                its_list.append(candidate_its)
                rc_sig.append(sig)
        return rc_list, its_list, rc_sig

    @staticmethod
    def _process(
        data_dict: Dict,
        its_key: str,
        rc_key: str,
        ignore_aromaticity: bool,
        balance_its: bool,
    ) -> Dict:
        """Expand one entry's ITS graph and store the deduplicated results.

        Parameters:
        - data_dict (Dict): Entry holding the input ITS graph under
          ``its_key``. It is not modified.
        - its_key (str): Key of the ITS graph. In the result it holds the
          list of representative ITS graphs.
        - rc_key (str): Key under which the list of representative
          reaction-center graphs is stored in the result.
        - ignore_aromaticity (bool): Whether to ignore aromaticity during
          graph processing.
        - balance_its (bool): Whether to balance the ITS graph.

        Returns:
        - Dict: A shallow copy of ``data_dict`` with ``its_key`` and
          ``rc_key`` replaced by index-aligned lists of graphs, one per
          cluster.
        """
        # Copy before writing. With n_jobs=1, joblib runs in-process, so
        # writing into ``data_dict`` would mutate the caller's input. With
        # worker processes it would not. Copying makes both paths agree.
        result = data_dict.copy()
        rc_list, its_list, rc_sig = HExtend._extend(
            result[its_key], ignore_aromaticity, balance_its
        )
        if not rc_list:
            # No usable candidate survived, so there is nothing to cluster.
            result[rc_key] = []
            result[its_key] = []
            return result

        cls, _ = cluster.iterative_cluster(rc_list, rc_sig)
        # The same cluster sets are used for both lists, so the i-th RC and
        # the i-th ITS come from the same candidate.
        result[rc_key] = HExtend.get_unique_graphs_for_clusters(rc_list, cls)
        result[its_key] = HExtend.get_unique_graphs_for_clusters(its_list, cls)
        return result

    @staticmethod
    def fit(
        data,
        its_key: str,
        rc_key: str,
        ignore_aromaticity: bool = False,
        balance_its: bool = True,
        n_jobs: int = 1,
        verbose: int = 0,
    ) -> List:
        """Process every entry of ``data`` in parallel, expanding its ITS
        graph and deduplicating the resulting reaction centers.

        Input entries are not modified. Each result is a new dictionary.

        Parameters:
        - data (iterable): Dictionaries holding an ITS graph under ``its_key``.
        - its_key (str): Key for the ITS graphs in the data.
        - rc_key (str): Key under which reaction-center graphs are stored.
        - ignore_aromaticity (bool): Whether to ignore aromaticity during
          processing. Defaults to False.
        - balance_its (bool): Whether to balance the ITS during processing.
          Defaults to True.
        - n_jobs (int): Number of jobs to run in parallel. Defaults to 1.
        - verbose (int): Verbosity level for joblib. Defaults to 0.

        Returns:
        - List: The processed entries, in the same order as ``data``.
        """
        results = Parallel(n_jobs=n_jobs, verbose=verbose, backend="multiprocessing")(
            delayed(HExtend._process)(
                item, its_key, rc_key, ignore_aromaticity, balance_its
            )
            for item in data
        )
        return results