Source code for synkit.Graph.Matcher.batch_cluster

import networkx as nx
from operator import eq
from typing import List, Dict, Any, Tuple, Optional, Callable
from networkx.algorithms.isomorphism import generic_node_match, generic_edge_match
from synkit.Utils.utils import stratified_random_sample
from synkit.Rule.Modify.rule_utils import strip_context
from synkit.Graph.Matcher.graph_cluster import GraphCluster
from synkit.Graph.Matcher.graph_morphism import graph_isomorphism

# from synkit.Graph.Matcher.rule_morphism import rule_isomorphism
from synkit.Graph.Matcher.graph_matcher import GraphMatcherEngine


[docs] class BatchCluster: def __init__( self, node_label_names: List[str] = ["element", "charge"], node_label_default: List[Any] = ["*", 0], edge_attribute: str = "order", backend: str = "nx", ): """Initializes an AutoCat instance which uses isomorphism checks for categorizing new graphs or rules. Parameters: - node_label_names (List[str]): Names of the node attributes to use in isomorphism checks. - node_label_default (List[Any]): Default values for node attributes if they are missing in the graph data. - edge_attribute (str): The edge attribute to consider when checking isomorphism between graphs. Raises: - ValueError: If the lengths of `node_label_names` and `node_label_default` do not match. """ self.backend = backend.lower() available = self.available_backends() if self.backend not in available: if self.backend == "rule": raise ImportError("MOD is not installed") raise ValueError(f"Unsupported backend: {backend!r}") if len(node_label_names) != len(node_label_default): raise ValueError( "The lengths of `node_label_names` and `node_label_default` must match." ) if backend == "nx": self.nodeLabelNames = node_label_names self.nodeLabelDefault = node_label_default self.edgeAttribute = edge_attribute self.nodeMatch = generic_node_match( self.nodeLabelNames, self.nodeLabelDefault, [eq] * len(node_label_names) ) self.edgeMatch = generic_edge_match(edge_attribute, 1, eq)
[docs] def available_backends(self) -> List[str]: """ Return available backends: always includes 'nx'; adds 'rule' if the 'mod' package is installed. """ import importlib.util backends = ["nx"] # Check if 'mod' package is importable without executing it if importlib.util.find_spec("mod") is not None: backends.append("rule") return backends
[docs] def lib_check( self, data: Dict, templates: List[Dict], rule_key: str = "gml", attribute_key: str = "signature", nodeMatch: Optional[Callable] = None, edgeMatch: Optional[Callable] = None, ) -> Dict: """Checks and classifies a graph or rule based on existing templates using either graph or rule isomorphism. Parameters: - data (Dict): A dictionary representing a graph or rule with its attributes and classification. - templates (List[Dict]): Dynamic templates used for categorization. If None, initializes to an empty list. - rule_key (str): Key to access the graph or rule data within the dictionary. - attribute_key (str): An attribute used to filter templates before isomorphism check. - nodeMatch (Optional[Callable]): A function to match nodes, defaults to a predefined generic_node_match. - edgeMatch (Optional[Callable]): A function to match edges, defaults to a predefined generic_edge_match. Returns: - Dict: The updated dictionary with its classification. """ # Ensure that templates are not None if templates is None: templates = [] att = data.get(attribute_key) sub_temp = [temp for temp in templates if temp.get(attribute_key) == att] for template in sub_temp: template_data = ( strip_context(template[rule_key]) if isinstance(template[rule_key], str) else template[rule_key] ) data_rule = ( strip_context(data[rule_key]) if isinstance(data[rule_key], str) else data[rule_key] ) if isinstance(data_rule, str): iso_function = GraphMatcherEngine()._isomorphic_rule apply_match_args = False elif isinstance(data_rule, nx.Graph): iso_function = graph_isomorphism apply_match_args = True if apply_match_args: if iso_function( template_data, data_rule, nodeMatch or self.nodeMatch, edgeMatch or self.edgeMatch, ): data["class"] = template["class"] break else: if iso_function(template_data, data_rule): data["class"] = template["class"] break else: new_class = max((temp["class"] for temp in templates), default=-1) + 1 data["class"] = new_class templates.append(data.copy()) # Append a copy to avoid reference issues return data, templates
[docs] @staticmethod def batch_dicts(input_list, batch_size): """Splits a list of dictionaries into batches of a specified size. Args: input_list (list of dict): The list of dictionaries to be batched. batch_size (int): The size of each batch. Returns: list of list of dict: A list where each element is a batch (sublist) of dictionaries. Raises: ValueError: If batch_size is less than 1. """ # Validate batch_size to ensure it's a positive integer if batch_size < 1: raise ValueError("batch_size must be at least 1") # Initialize an empty list to hold the batches batches = [] # Iterate over the input list in steps of batch_size for i in range(0, len(input_list), batch_size): # Append a batch slice to the batches list # fmt: off batches.append(input_list[i: i + batch_size]) # fmt: on return batches
[docs] def cluster( self, data: List[Dict], templates: List[Dict], rule_key: str = "gml", attribute_key: str = "WLHash", ) -> Tuple[List[Dict], List[Dict]]: """Processes a list of graph data entries, classifying each based on existing templates. Parameters: - data (List[Dict]): A list of dictionaries, each representing a graph or rule to be classified. - templates (List[Dict]): Dynamic templates used for categorization. Returns: - Tuple[List[Dict], List[Dict]]: A tuple containing the list of classified data and the updated templates. """ for entry in data: _, templates = self.lib_check(entry, templates, rule_key, attribute_key) return data, templates
[docs] def fit( self, data: List[Dict], templates: List[Dict], rule_key: str = "gml", attribute_key: str = "WLHash", batch_size: Optional[int] = None, ) -> Tuple[List[Dict], List[Dict]]: """Processes and classifies data in batches. Uses GraphCluster for initial processing and a stratified sampling technique to update templates if there is only one batch and no initial templates are provided. Parameters: - data (List[Dict]): Data to process. - templates (List[Dict]): Templates for categorization. - rule_key (str): Key to access rule or graph data. - attribute_key (str): Key to access attributes used for filtering. - batch_size (Optional[int]): Size of batches for processing, if not provided, processes all data at once. Returns: - Tuple[List[Dict], List[Dict]]: The processed data and the potentially updated templates. """ if batch_size is not None: batches = self.batch_dicts(data, batch_size) else: batches = [data] # Process all at once if no batch size provided output_data, output_templates = [], templates if templates is not None else [] graph_cluster = GraphCluster() if len(batches) == 1: batch = batches[0] if not templates: output_data = graph_cluster.fit(batch, rule_key, attribute_key) output_templates = stratified_random_sample( output_data, property_key="class", samples_per_class=1, seed=1 ) else: output_data, output_templates = self.cluster( batch, output_templates, rule_key, attribute_key ) else: for batch in batches: processed_data, new_templates = self.cluster( batch, output_templates, rule_key, attribute_key ) output_data.extend(processed_data) output_templates = new_templates return output_data, output_templates