# Source code for synkit.CRN.Construct.flattener
from __future__ import annotations
from collections import Counter
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set
import networkx as nx
# [docs]  (appears to be a documentation-link marker left over from the rendered source page)
@dataclass
class ReactionDeltaFlattener:
    """Flatten the rule nodes of a directed graph into net reaction records.

    The graph is read through the following attributes only:

    * nodes whose ``kind`` is ``"rule"`` are the events that get flattened;
    * nodes whose ``kind`` is ``"species"`` supply a ``smiles`` label
      (``str(node)`` is used when the label is missing or ``None``);
    * edges entering a rule node with ``role == "reactant"`` and edges
      leaving it with ``role == "product"`` connect it to its species; an
      optional integer ``stoich`` edge attribute repeats the species that
      many times (1 when missing or ``None``).

    For every rule node, labels that occur on both sides are cancelled
    multiset-wise, and what is left is rendered as
    ``"<reactants>>><products>"`` with ``"."`` as separator.

    Attributes:
        graph: The graph to read (accessed via ``nodes``, ``in_edges`` and
            ``out_edges``).
        skip_no_change: Drop rule nodes for which both sides are empty
            after cancellation.
        allow_empty_side: When ``False``, drop rule nodes for which either
            side is empty after cancellation.
        deduplicate: Keep only the first rule node producing a given
            reaction string.
        _cache: Records computed by the last :meth:`build` call.
    """

    graph: nx.DiGraph
    skip_no_change: bool = True
    allow_empty_side: bool = False
    deduplicate: bool = True
    _cache: List[Dict[str, Any]] = field(default_factory=list, init=False)

    def build(self) -> "ReactionDeltaFlattener":
        """Compute the flattened records, store them, and return ``self``."""
        self._cache = self._flatten()
        return self

    @property
    def reactions(self) -> List[Dict[str, Any]]:
        """Return a new list holding the records of the last :meth:`build`.

        The list is a shallow copy: the record dictionaries themselves are
        shared with the internal cache.
        """
        return list(self._cache)

    def _collect_side(self, edges: Any, role: str, endpoint: int) -> List[str]:
        """Expand the species attached through ``edges`` into a label list.

        Args:
            edges: Iterable of ``(u, v, attrs)`` edge triples.
            role: Required value of the edge attribute ``role``.
            endpoint: Index inside the triple (0 or 1) of the species node.

        Returns:
            One label per unit of stoichiometry, in edge iteration order.
        """
        labels: List[str] = []
        for edge in edges:
            attrs = edge[2]
            if attrs.get("role") != role:
                continue
            node = edge[endpoint]
            node_attrs = self.graph.nodes[node]
            if node_attrs.get("kind") != "species":
                continue
            # A label that is present but None must fall back to the node
            # id as well; otherwise None would reach sorted()/str.join().
            smiles = node_attrs.get("smiles")
            label = str(node) if smiles is None else smiles
            # Same reasoning for an explicit ``stoich=None``: treat as 1.
            stoich = attrs.get("stoich")
            count = 1 if stoich is None else int(stoich)
            labels.extend([label] * count)
        return labels

    def _collect_in(self, eid: int) -> List[str]:
        """Return reactant labels of rule node ``eid`` (incoming edges)."""
        return self._collect_side(
            self.graph.in_edges(eid, data=True), "reactant", 0
        )

    def _collect_out(self, eid: int) -> List[str]:
        """Return product labels of rule node ``eid`` (outgoing edges)."""
        return self._collect_side(
            self.graph.out_edges(eid, data=True), "product", 1
        )

    def _nz(self, x: Optional[int]) -> int:
        """Return ``int(x)``, or ``10**9`` for ``None`` so it sorts last."""
        return 10**9 if x is None else int(x)

    def _flatten(self) -> List[Dict[str, Any]]:
        """Build one record per retained rule node.

        Each record holds ``rule_id``, the node attributes ``label``,
        ``step``, ``rule_index`` and ``app_index`` (``None`` when absent),
        the sorted ``reactants`` and ``products`` left after cancellation,
        and the resulting ``rule_smiles`` string.

        Records are ordered by ``(step, rule_index, app_index, rule_id)``
        with ``None`` values placed last. Duplicates are filtered before
        that sort, so the record kept for a repeated reaction string is the
        first one met in the graph's node iteration order.
        """
        records: List[Dict[str, Any]] = []
        seen: Set[str] = set()

        for eid, nd in self.graph.nodes(data=True):
            if nd.get("kind") != "rule":
                continue

            r_counter = Counter(self._collect_in(eid))
            p_counter = Counter(self._collect_out(eid))
            # Multiset intersection: the copies present on both sides.
            common = r_counter & p_counter
            rchg = sorted((r_counter - common).elements())
            pchg = sorted((p_counter - common).elements())

            if self.skip_no_change and not rchg and not pchg:
                continue
            if not self.allow_empty_side and not (rchg and pchg):
                continue

            rxn_smiles = f"{'.'.join(rchg)}>>{'.'.join(pchg)}"
            if self.deduplicate:
                if rxn_smiles in seen:
                    continue
                seen.add(rxn_smiles)

            records.append(
                {
                    "rule_id": eid,
                    "label": nd.get("label"),
                    "step": nd.get("step"),
                    "rule_index": nd.get("rule_index"),
                    "app_index": nd.get("app_index"),
                    "reactants": rchg,
                    "products": pchg,
                    "rule_smiles": rxn_smiles,
                }
            )

        records.sort(
            key=lambda r: (
                self._nz(r.get("step")),
                self._nz(r.get("rule_index")),
                self._nz(r.get("app_index")),
                r["rule_id"],
            )
        )
        return records