# Source code for synkit.Graph.FG.model
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Callable, Iterable
import networkx as nx
# Integer-to-integer node correspondence; stored as FunctionalGroupMatch.mapping
# and passed to Validator callables.
# NOTE(review): direction is presumably pattern node id -> host node id; confirm
# against the matcher that builds it.
Mapping = dict[int, int]
# Optional extra check on a candidate match: called with a host graph and a
# node Mapping, returns a bool (presumably True keeps the match — confirm).
Validator = Callable[[nx.Graph, Mapping], bool]
# Custom matcher hook: called with a host graph and a FunctionalGroupPattern,
# returns a list of FunctionalGroupMatch objects. The class names are quoted
# because these aliases are evaluated at runtime before the classes below are
# defined (``from __future__ import annotations`` does not cover alias values).
Recognizer = Callable[
    [nx.Graph, "FunctionalGroupPattern"], list["FunctionalGroupMatch"]
]
[docs]
@dataclass(eq=True, frozen=True)
class FunctionalGroupPattern:
    """Graph-native functional-group definition.

    An immutable record (``frozen=True``) that compares field by field.
    Everything except ``name``, ``graph`` and ``group_nodes`` is optional
    metadata with a neutral default.
    """

    # Identifier; FunctionalGroupRegistry resolves patterns by this value.
    name: str
    # The pattern's own graph.
    graph: nx.Graph
    # Node ids forming the group (presumably nodes of ``graph`` — confirm).
    group_nodes: tuple[int, ...]
    # Names of more general patterns; FunctionalGroupRegistry.is_ancestor
    # follows these links transitively.
    parents: tuple[str, ...] = field(default=())
    # Names of patterns this one suppresses; how suppression is applied is
    # not visible in this module.
    suppresses: tuple[str, ...] = field(default=())
    # Names of patterns that FunctionalGroupRegistry.execution_order places
    # before this one.
    requires: tuple[str, ...] = field(default=())
    # Optional distinguished node id; its consumer is not shown here.
    anchor_node: int | None = field(default=None)
    # Integer priority (default 0); its interpretation lives elsewhere.
    priority: int = field(default=0)
    # Optional callable (host graph, node mapping) -> bool.
    validator: Validator | None = field(default=None)
    # Optional callable (host graph, pattern) -> list of FunctionalGroupMatch.
    recognizer: Recognizer | None = field(default=None)
    # Visibility flag, True by default (presumably controls whether matches
    # of this pattern are reported — confirm with callers).
    public: bool = field(default=True)
[docs]
@dataclass(eq=True, frozen=True)
class FunctionalGroupMatch:
    """One matched functional group in a host graph.

    Immutable record (``frozen=True``) linking a matched node set back to
    the pattern that produced it. Because ``mapping`` is a dict, the
    generated ``__hash__`` raises ``TypeError``: instances compare by value
    but cannot be used as set members or dict keys.
    """

    # Name of the matched group (presumably the pattern's name — confirm).
    name: str
    # Node ids of the group (presumably in the host graph — confirm).
    group_nodes: tuple[int, ...]
    # Integer-to-integer node correspondence for this match.
    mapping: Mapping
    # The pattern that produced this match.
    pattern: FunctionalGroupPattern
[docs]
@dataclass
class FunctionalGroupRegistry:
    """Container for functional-group patterns and hierarchy metadata.

    Patterns are kept in insertion order in ``patterns``. Names are not
    required to be unique; wherever a name is resolved to a pattern
    (``by_name``, ``is_ancestor``, ``execution_order``) the FIRST pattern
    registered under that name is used, so all three agree.
    """

    patterns: list[FunctionalGroupPattern] = field(default_factory=list)

    def add(self, pattern: FunctionalGroupPattern) -> None:
        """Append ``pattern`` to the registry (no duplicate-name check)."""
        self.patterns.append(pattern)

    def extend(self, patterns: Iterable[FunctionalGroupPattern]) -> None:
        """Append every pattern from ``patterns``, preserving their order."""
        self.patterns.extend(patterns)

    def by_name(self, name: str) -> FunctionalGroupPattern:
        """Return the first registered pattern whose ``name`` equals ``name``.

        Raises:
            KeyError: if no registered pattern has that name.
        """
        for pattern in self.patterns:
            if pattern.name == name:
                return pattern
        raise KeyError(name)

    def _first_by_name(self) -> dict[str, FunctionalGroupPattern]:
        """Index patterns by name, keeping the first per name like ``by_name``."""
        index: dict[str, FunctionalGroupPattern] = {}
        for pattern in self.patterns:
            # setdefault keeps the earliest registration for duplicate names.
            index.setdefault(pattern.name, pattern)
        return index

    def is_ancestor(self, ancestor: str, child: str) -> bool:
        """Return whether ``ancestor`` is an ancestor of ``child``.

        Follows ``parents`` links transitively starting from ``child``.
        Names with no registered pattern are treated as having no parents.
        A name is only its own ancestor if a ``parents`` cycle leads back
        to it; cycles otherwise terminate thanks to the ``seen`` set.
        """
        # Build the lookup once instead of a linear by_name scan per node.
        index = self._first_by_name()
        seen: set[str] = set()
        stack = [child]
        while stack:
            current = stack.pop()
            if current in seen:
                continue
            seen.add(current)
            known = index.get(current)
            for parent in (known.parents if known is not None else ()):
                if parent == ancestor:
                    return True
                stack.append(parent)
        return False

    def execution_order(self) -> list[FunctionalGroupPattern]:
        """Return patterns in prerequisite-respecting order.

        Each pattern appears after the patterns named in its ``requires``
        (depth-first post-order, roots taken in registration order).
        ``requires`` entries naming unregistered patterns are ignored, and
        each name appears once. For a ``requires`` cycle no valid order
        exists; the cycle is broken at the first back-edge reached and no
        error is raised.

        Uses an explicit stack, so long ``requires`` chains cannot hit the
        interpreter recursion limit.
        """
        index = self._first_by_name()
        visited: set[str] = set()
        ordered: list[FunctionalGroupPattern] = []
        for root in self.patterns:
            if root.name in visited:
                continue
            visited.add(root.name)
            start = index[root.name]
            # Each frame keeps its own iterator over ``requires`` so the
            # scan resumes where it left off after a dependency finishes.
            stack = [(start, iter(start.requires))]
            while stack:
                current, pending = stack[-1]
                for required in pending:
                    if required in index and required not in visited:
                        visited.add(required)
                        dependency = index[required]
                        stack.append((dependency, iter(dependency.requires)))
                        break
                else:
                    # All prerequisites handled: emit in post-order.
                    stack.pop()
                    ordered.append(current)
        return ordered