# Source code for synkit.Rule.Modify.prune_templates

import networkx as nx
from copy import deepcopy
from synkit.Rule.Modify.longest_path import LongestPath
from typing import List, Dict, Any


class PruneTemplate:
    """Prune radius-indexed template lists by longest-path length.

    ``templates[r]`` holds the template dictionaries for radius ``r``. Each
    dictionary is expected to store, under ``graph_key``, an indexable value
    whose element at index 2 is the graph inspected by :meth:`fit`.
    """

    def __init__(self, templates: List[List[Dict[str, Any]]], graph_key: str) -> None:
        """Initialize the PruneTemplate object with the provided templates and
        graph key.

        Parameters:
        - templates (List[List[Dict[str, Any]]]): A list of lists containing
          dictionaries where the graph can be accessed by the provided
          graph_key. The outer index is used as the radius by ``fit``.
        - graph_key (str): The key used to access the graph from each
          template dictionary.
        """
        # Number of radius levels supplied (one inner list per radius).
        self.max_radius = len(templates)
        # Work on a private deep copy so pruning never mutates the caller's data.
        self.templates = deepcopy(templates)
        self.graph_key = graph_key

    @staticmethod
    def remove_edges_by_attribute(
        input_graph: nx.Graph, attribute: str = "standard_order", value: Any = 0
    ) -> nx.Graph:
        """Remove edges from the input graph where a given attribute equals a
        specified value.

        Edges that do not carry ``attribute`` are looked up as ``None``; with
        the default ``value`` of 0 they are therefore kept.

        Parameters:
        - input_graph (nx.Graph): The input graph from which edges will be
          removed. It is not modified.
        - attribute (str, optional): The edge attribute based on which edges
          will be removed. Default is 'standard_order'.
        - value (Any, optional): The value of the attribute that determines
          which edges to remove. Default is 0.

        Returns:
        nx.Graph: A new graph (deep copy) with the specified edges removed.
        """
        graph = deepcopy(input_graph)
        # Collect the matching edges first; removing while iterating over
        # graph.edges would mutate the structure being traversed.
        edges_to_remove = [
            (u, v)
            for u, v, attrs in graph.edges(data=True)
            if attrs.get(attribute) == value
        ]
        graph.remove_edges_from(edges_to_remove)
        return graph

    def fit(self) -> List[List[Dict[str, Any]]]:
        """Prune the templates by removing subgraphs where the longest path is
        shorter than the radius.

        Radius 0 is never pruned. Entries that lack ``graph_key``, or whose
        graph slot (index 2) is ``None``, are skipped and left in place.

        Returns:
        List[List[Dict[str, Any]]]: The pruned list of templates (the same
        object as ``self.templates``, modified in place).
        """
        for radius, template in enumerate(self.templates):
            if radius == 0:
                continue
            # Walk indices backwards so pop() never shifts an index that is
            # still waiting to be visited.
            for index in reversed(range(len(template))):
                entry = template[index].get(self.graph_key)
                if entry is None:
                    # Key missing: nothing to measure, keep the template.
                    continue
                subgraph = entry[2]
                if subgraph is None:
                    continue
                pruned_graph = PruneTemplate.remove_edges_by_attribute(subgraph)
                path_calculator = LongestPath(pruned_graph)
                longest_path = path_calculator.LongestPathInDisconnectedGraph()
                if longest_path < radius:
                    template.pop(index)
        return self.templates