Source code for synkit.Chem.Molecule.graph_annotator

from __future__ import annotations

import logging
from typing import Any, Dict, List, Set

import networkx as nx

logger = logging.getLogger(__name__)

Node = Any
NodeData = Dict[str, Any]


[docs] class GraphAnnotator: """ Compute optional topology annotations for a NetworkX molecular graph. The annotator mutates a graph (in-place by default) or a shallow copy if `in_place=False`. Methods are chainable and return ``self``; use the ``.graph`` property to retrieve the annotated graph. Supported annotations: - node degree -> ``atom_degree`` - neighborhood element counts -> ``nbr_elements_counts_r1`` - shortest distances to motif sets (halogen/hetero/aromatic/carbonyl) -> ``dist_to_<motif>`` - conjugated/pi connected component size -> ``conj_component_size`` - ring sizes via cycle basis -> ``ring_sizes`` and updated ``in_ring`` Notes ----- * Graph nodes are expected to carry an ``'element'`` key (string) and may optionally have ``'aromatic'`` or ``'is_halogen'`` boolean flags. * Edge attributes used: ``'order'`` (numeric-like) and ``'conjugated'``. """ DEFAULT_MAX_DISTANCE = 99 def __init__( self, G: nx.Graph, in_place: bool = True, max_distance: int = DEFAULT_MAX_DISTANCE, ): self._original = G self._in_place = bool(in_place) self._max_distance = int(max_distance) self._graph: nx.Graph = G if in_place else G.copy()
[docs] def annotate(self) -> "GraphAnnotator": self._degrees() self._neighbor_counts() self._rings_via_cycle_basis() self._distances_to_motifs() self._pi_components() return self
@property def graph(self) -> nx.Graph: return self._graph
[docs] def get_node(self, n: Node) -> NodeData: return self._graph.nodes[n]
[docs] def get_node_attr(self, n: Node, attr: str, default: Any = None) -> Any: return self._graph.nodes[n].get(attr, default)
def _degrees(self) -> "GraphAnnotator": try: for n in self._graph.nodes(): self._graph.nodes[n]["atom_degree"] = int(self._graph.degree(n)) except Exception as exc: logger.debug("Failed to annotate degrees: %s", exc, exc_info=True) return self def _neighbor_counts(self) -> "GraphAnnotator": try: for n in self._graph.nodes(): counts: Dict[str, int] = {} for nb in self._graph.neighbors(n): el = self._graph.nodes[nb].get("element", "C") counts[el] = counts.get(el, 0) + 1 self._graph.nodes[n]["nbr_elements_counts_r1"] = counts except Exception as exc: logger.debug("Failed to compute neighbor counts: %s", exc, exc_info=True) return self def _rings_via_cycle_basis(self) -> "GraphAnnotator": try: cycles = nx.cycle_basis(self._graph) per_node: Dict[Node, List[int]] = {} for cyc in cycles: size = len(cyc) for node in cyc: per_node.setdefault(node, []).append(size) for n in self._graph.nodes(): sizes = per_node.get(n, []) self._graph.nodes[n].setdefault("ring_sizes", sizes) self._graph.nodes[n].setdefault( "in_ring", bool(sizes or self._graph.nodes[n].get("in_ring", False)) ) except Exception as exc: logger.debug( "Failed to compute ring_sizes via cycle_basis: %s", exc, exc_info=True ) for n in self._graph.nodes(): self._graph.nodes[n].setdefault( "ring_sizes", self._graph.nodes[n].get("ring_sizes", []) ) self._graph.nodes[n].setdefault( "in_ring", bool(self._graph.nodes[n].get("in_ring", False)) ) return self def _pi_components(self) -> "GraphAnnotator": try: PG = nx.Graph() for u, v, ed in self._graph.edges(data=True): try: order = float(ed.get("order", 1.0)) except Exception: order = 1.0 conj = bool(ed.get("conjugated", False)) arom = bool( self._graph.nodes[u].get("aromatic", False) and self._graph.nodes[v].get("aromatic", False) ) if order >= 2.0 or conj or arom: PG.add_edge(u, v) for n, d in self._graph.nodes(data=True): if d.get("aromatic", False) and n not in PG: PG.add_node(n) comp_size: Dict[Node, int] = {} for comp in nx.connected_components(PG): s = len(comp) for node in comp: comp_size[node] = s for n in self._graph.nodes(): self._graph.nodes[n]["conj_component_size"] = int(comp_size.get(n, 0)) except Exception as exc: logger.debug("Failed to compute pi components: %s", exc, exc_info=True) for n in self._graph.nodes(): self._graph.nodes[n].setdefault("conj_component_size", 0) return self # --- New small motif collectors (each purposely tiny) ----------------- def _collect_halogens(self) -> Set[Node]: try: return { n for n, d in self._graph.nodes(data=True) if d.get("element") in {"F", "Cl", "Br", "I"} or d.get("is_halogen") } except Exception: logger.debug("Failed to collect halogens", exc_info=True) return set() def _collect_hetero(self) -> Set[Node]: try: return { n for n, d in self._graph.nodes(data=True) if d.get("element") not in {"C", "H"} } except Exception: logger.debug("Failed to collect heteroatoms", exc_info=True) return set() def _collect_aromatic(self) -> Set[Node]: try: return { n for n, d in self._graph.nodes(data=True) if d.get("aromatic", False) } except Exception: logger.debug("Failed to collect aromatic nodes", exc_info=True) return set() def _collect_carbonyl(self) -> Set[Node]: carbonyl: Set[Node] = set() try: for u, v, ed in self._graph.edges(data=True): try: order = float(ed.get("order", 1.0)) except Exception: order = 1.0 elems = { self._graph.nodes[u].get("element"), self._graph.nodes[v].get("element"), } if order >= 2.0 and elems == {"C", "O"}: if self._graph.nodes[u].get("element") == "C": carbonyl.add(u) if self._graph.nodes[v].get("element") == "C": carbonyl.add(v) except Exception: logger.debug("Failed to collect carbonyls", exc_info=True) return set() return carbonyl def _collect_motif_sets(self) -> Dict[str, Set[Node]]: """ Collect node sets for motifs: halogen, hetero, aromatic, carbonyl. Delegates to micro-helpers to keep function complexity low. """ return { "halogen": self._collect_halogens(), "hetero": self._collect_hetero(), "aromatic": self._collect_aromatic(), "carbonyl": self._collect_carbonyl(), } def _compute_distances_for_motif( self, sources: Set[Node], cutoff: int ) -> Dict[Node, int]: if not sources: return {} try: return dict( nx.multi_source_shortest_path_length( self._graph, sources, cutoff=cutoff ) ) except Exception as exc: logger.debug( "multi_source_shortest_path_length failed: %s", exc, exc_info=True ) return {} def _fallback_distances( self, motifs: Dict[str, Set[Node]], cutoff: int ) -> Dict[str, Dict[Node, int]]: distances_by_motif: Dict[str, Dict[Node, int]] = {} try: all_sp = dict(nx.all_pairs_shortest_path_length(self._graph, cutoff=cutoff)) for name, sources in motifs.items(): db: Dict[Node, int] = {} if not sources: distances_by_motif[name] = db continue for node in self._graph.nodes(): dmap = all_sp.get(node, {}) found = [dmap.get(s) for s in sources if s in dmap] db[node] = min(found) if found else cutoff distances_by_motif[name] = db return distances_by_motif except Exception as exc: logger.debug( "Fallback all-pairs shortest path failed: %s", exc, exc_info=True ) return {name: {} for name in motifs.keys()} def _distances_to_motifs(self) -> "GraphAnnotator": maxd = int(self._max_distance) try: motifs = self._collect_motif_sets() distances_by_motif: Dict[str, Dict[Node, int]] = {} for name, sources in motifs.items(): distances_by_motif[name] = self._compute_distances_for_motif( sources, cutoff=maxd ) need_fallback = any( len(motifs[name]) > 0 and not distances_by_motif.get(name) for name in motifs.keys() ) if need_fallback: distances_by_motif = self._fallback_distances(motifs, cutoff=maxd) for n in self._graph.nodes(): for name in motifs.keys(): dmap = distances_by_motif.get(name, {}) dist = dmap.get(n, maxd) self._graph.nodes[n][f"dist_to_{name}"] = int(dist) except Exception as exc: logger.debug( "Failed to compute distances to motifs: %s", exc, exc_info=True ) for n in self._graph.nodes(): for name in ("halogen", "hetero", "aromatic", "carbonyl"): self._graph.nodes[n].setdefault(f"dist_to_{name}", maxd) return self def __repr__(self) -> str: try: n = self._graph.number_of_nodes() except Exception: n = -1 return f"{self.__class__.__name__}(in_place={self._in_place}, n_nodes={n})"
[docs] @classmethod def help(cls) -> str: return ( "GraphAnnotator.help() -> str\n\n" "This annotator computes: atom_degree, nbr_elements_counts_r1, dist_to_<motif>,\n" "conj_component_size, ring_sizes and in_ring. Call `.annotate()` to run\n" "the pipeline and retrieve results via `.graph`." )