# Source code for synkit.Chem.Molecule.atom_features

from __future__ import annotations

from collections import deque
from typing import Any, Callable, Dict, List, Optional, Protocol, runtime_checkable

from rdkit import Chem

from .valence import ValenceResolver
from .descriptors import PerMolDescriptors


@runtime_checkable
class _AtomLike(Protocol):
    """
    Structural type for the atom objects consumed by AtomFeatureExtractor.

    Any object exposing these RDKit-style accessors (a real ``Chem.Atom`` or
    a lightweight stand-in) satisfies the protocol. Because the protocol is
    ``runtime_checkable``, ``isinstance`` only verifies that the members
    exist; signatures and return types are not checked at runtime.
    """

    # --- identity / element -------------------------------------------------
    def GetIdx(self) -> int:
        """Index of the atom within its molecule."""

    def GetSymbol(self) -> str:
        """Element symbol of the atom."""

    def GetIsotope(self) -> int:
        """Isotope number of the atom."""

    def GetAtomMapNum(self) -> int:
        """Atom-map number of the atom."""

    # --- electronic state ---------------------------------------------------
    def GetFormalCharge(self) -> int:
        """Formal charge of the atom."""

    def GetNumRadicalElectrons(self) -> int:
        """Number of radical electrons on the atom."""

    def GetHybridization(self) -> Any:
        """Hybridization descriptor of the atom."""

    def GetIsAromatic(self) -> bool:
        """Whether the atom is flagged aromatic."""

    def GetChiralTag(self) -> Any:
        """Chiral tag of the atom."""

    # --- hydrogens / connectivity --------------------------------------------
    def GetTotalNumHs(self) -> int:
        """Total hydrogen count of the atom."""

    def GetNumImplicitHs(self) -> int:
        """Implicit hydrogen count of the atom."""

    def GetDegree(self) -> int:
        """Degree of the atom."""

    def GetNeighbors(self):
        """Iterable of neighbouring atom objects."""

    def IsInRing(self) -> bool:
        """Whether the atom is a ring member."""

    # --- property bag ---------------------------------------------------------
    def HasProp(self, name: str) -> bool:
        """Whether the property ``name`` is set on the atom."""

    def GetProp(self, name: str) -> str:
        """Value of the property ``name`` as a string."""

    def GetDoubleProp(self, name: str) -> float:
        """Value of the property ``name`` as a float."""


class AtomFeatureExtractor:
    """
    Build per-atom feature dictionaries for an RDKit molecule.

    The extractor supports two profiles:

    - ``"minimal"`` : a compact set of attributes (backwards compatible with
      the original `_gather_atom_properties`).
    - ``"full"`` : includes valence, ring sizes, neighbor counts, shortest
      distances to functional groups, and optional descriptors from
      :class:`PerMolDescriptors`.

    The class exposes a fluent API: ``.build(atom)`` returns ``self`` and
    stores the result in ``.feature`` (dict). For batch processing use
    ``.build_all()`` and read ``.all_features`` afterwards. For one-off usage,
    the compatibility helper ``.build_dict(atom)`` returns the feature dict
    directly.

    :param mol: RDKit molecule to extract features from.
    :param per: Optional precomputed per-atom descriptors (EState, Crippen, etc.).
    :param profile: Feature profile to compute (``"minimal"`` or ``"full"``).
    :raises ValueError: if ``profile`` is not listed in ``SUPPORTED_PROFILES``.
    """

    SUPPORTED_PROFILES = ("minimal", "full")

    def __init__(
        self,
        mol: Chem.Mol,
        per: Optional[PerMolDescriptors] = None,
        profile: str = "minimal",
    ):
        if profile not in self.SUPPORTED_PROFILES:
            raise ValueError(
                f"Unsupported profile: {profile!r}. Supported: {self.SUPPORTED_PROFILES}"
            )
        self.mol: Chem.Mol = mol
        self.per: Optional[PerMolDescriptors] = per
        self.profile: str = profile

        # results (filled by build/build_all)
        self._last_feature: Optional[Dict[str, Any]] = None
        self._all_features: Optional[List[Dict[str, Any]]] = None

    # ---------------- fluent / compatibility API -------------------------
    def build(self, atom: Chem.Atom | _AtomLike) -> "AtomFeatureExtractor":
        """
        Compute features for *one* atom and store them internally.

        Returns self to enable chaining. The result dictionary can be
        accessed via the ``feature`` property.

        :param atom: RDKit Atom instance (or Atom-like object).
        :returns: self
        """
        self._last_feature = self.build_dict(atom)
        return self

    def build_dict(self, atom: Chem.Atom | _AtomLike) -> Dict[str, Any]:
        """
        Backwards-compatible helper that returns the computed feature dict
        directly (does not alter ``.feature`` or ``.all_features``).

        :param atom: RDKit Atom instance (or Atom-like object).
        :returns: feature dictionary
        """
        if self.profile == "full":
            return self._build_full(atom)
        return self._build_minimal(atom)

    def build_all(self) -> "AtomFeatureExtractor":
        """
        Compute features for *all* atoms in the molecule and store them in
        ``.all_features``. Returns self for chaining.

        This is deliberately best-effort: if iterating the molecule or
        building any atom's features raises, the stored result is an empty
        list (partial results are discarded) and no exception propagates.

        :returns: self
        """
        try:
            features = [self.build_dict(atom) for atom in self.mol.GetAtoms()]
        except Exception:
            # Defensive: if iteration fails store an empty list
            features = []
        self._all_features = features
        return self

    # ---------------- properties to retrieve results ---------------------
    @property
    def feature(self) -> Dict[str, Any]:
        """
        The last computed feature dictionary (via ``build``).

        A shallow copy is returned, so replacing keys on the result does not
        affect the stored dictionary (nested lists/dicts are still shared).

        :raises RuntimeError: if ``build`` has not been called yet.
        """
        if self._last_feature is None:
            raise RuntimeError("No features computed yet — call `build(atom)` first.")
        return dict(self._last_feature)

    @property
    def all_features(self) -> List[Dict[str, Any]]:
        """
        List of feature dicts for every atom (populated by ``build_all``).

        If ``build_all`` was not called, this property will call it lazily.
        The returned list is a shallow copy of the stored one.
        """
        if self._all_features is None:
            self.build_all()
        # build_all always assigns a list; the fallback avoids relying on an
        # ``assert`` (stripped under ``python -O``) to guarantee a list here.
        return list(self._all_features or [])

    # ---------------- minimal (backwards-compatible) --------------------
    @staticmethod
    def _gasteiger_charge(atom: Chem.Atom | _AtomLike) -> float:
        """
        Read the ``_GasteigerCharge`` atom property tolerantly.

        The property is read with ``GetProp`` first and ``GetDoubleProp`` as
        a fallback. Any access or conversion failure, a missing property, or
        a non-finite value (NaN / +-inf) yields ``0.0`` so that the feature
        dict never carries NaN or infinity.

        :param atom: RDKit Atom instance (or Atom-like object).
        :returns: finite partial charge, or ``0.0`` when unavailable.
        """
        try:
            if not atom.HasProp("_GasteigerCharge"):
                return 0.0
        except Exception:
            return 0.0

        for getter_name in ("GetProp", "GetDoubleProp"):
            try:
                value = float(getattr(atom, getter_name)("_GasteigerCharge"))
            except Exception:
                continue
            # The chained comparison is False for NaN as well as for +-inf.
            if float("-inf") < value < float("inf"):
                return value
            return 0.0
        return 0.0

    def _build_minimal(self, atom: Chem.Atom | _AtomLike) -> Dict[str, Any]:
        """
        Minimal feature set, intended to match the original helper.

        ``partial_charge``, ``neighbors`` and ``atom_map`` are read
        tolerantly and fall back to ``0.0``, ``[]`` and ``0`` respectively;
        the remaining accessors are called directly and may raise.

        :param atom: RDKit Atom instance (or Atom-like object).
        :returns: dict of features.
        """
        gcharge = self._gasteiger_charge(atom)

        # neighbors list (safe)
        try:
            neighbor_symbols = sorted(nb.GetSymbol() for nb in atom.GetNeighbors())
        except Exception:
            neighbor_symbols = []

        try:
            atom_map = int(atom.GetAtomMapNum())
        except Exception:
            atom_map = 0

        return {
            "element": atom.GetSymbol(),
            "aromatic": bool(atom.GetIsAromatic()),
            "hcount": int(atom.GetTotalNumHs()),
            "charge": int(atom.GetFormalCharge()),
            "radical": int(atom.GetNumRadicalElectrons()),
            "isomer": self._stereo_atom(atom),
            "partial_charge": round(gcharge, 3),
            "hybridization": str(atom.GetHybridization()),
            "in_ring": bool(atom.IsInRing()),
            "neighbors": neighbor_symbols,
            "atom_map": atom_map,
        }

    # ---------------- full profile (extra descriptors) ------------------
    def _build_full(self, atom: Chem.Atom | _AtomLike) -> Dict[str, Any]:
        """
        Full feature set, extends minimal with additional computed properties.

        Adds valence values, chirality flags, ring sizes, first-shell
        neighbor element counts, shortest bond distances to selected atom
        classes, and (when ``per`` was supplied) per-atom descriptor values.

        :param atom: RDKit Atom instance (or Atom-like object).
        :returns: dict of features
        """
        features = self._build_minimal(atom)

        # valence (safe)
        try:
            explicit = ValenceResolver.explicit(atom)
            implicit = ValenceResolver.implicit(atom)
        except Exception:
            explicit = 0
            implicit = 0

        idx = atom.GetIdx()

        # core additions
        features.update(
            {
                "explicit_valence": int(explicit),
                "implicit_valence": int(implicit),
                "valence": int(explicit + implicit),
                # Same value as "hcount": _build_minimal already read
                # GetTotalNumHs() unconditionally.
                "total_num_hs": features["hcount"],
                "chiral_tag": str(atom.GetChiralTag()),
                "is_chiral_center": (
                    bool(atom.HasProp("_ChiralityPossible") and atom.GetDegree() > 0)
                    if hasattr(atom, "HasProp")
                    else False
                ),
                "ring_sizes": self._ring_sizes(atom),
                "nbr_elements_counts_r1": self._neighbor_counts(atom),
                "dist_to_carbonyl": self._dist_to(self._is_carbonyl_atom, idx),
                "dist_to_hetero": self._dist_to(
                    lambda a: a.GetSymbol() not in {"C", "H"}, idx
                ),
                "dist_to_halogen": self._dist_to(
                    lambda a: a.GetSymbol() in {"F", "Cl", "Br", "I"}, idx
                ),
                "dist_to_aromatic": self._dist_to(lambda a: a.GetIsAromatic(), idx),
                "alpha_to_carbonyl": any(
                    self._is_carbonyl_atom(nb) for nb in atom.GetNeighbors()
                ),
            }
        )

        # Optional: estates/crippen if provided in PerMolDescriptors
        if self.per is not None:
            if idx < len(self.per.estate):
                features["estate"] = float(self.per.estate[idx])
            if idx < len(self.per.crippen_logp):
                features["crippen_logp"] = float(self.per.crippen_logp[idx])
            if idx < len(self.per.crippen_mr):
                features["crippen_mr"] = float(self.per.crippen_mr[idx])

        return features

    # ---------------- helpers (small, well-typed) ------------------------
    @staticmethod
    def _stereo_atom(atom: Chem.Atom | _AtomLike) -> str:
        """
        Map RDKit chiral tags to simple stereodescriptors.

        ``CHI_TETRAHEDRAL_CCW`` maps to "S" and ``CHI_TETRAHEDRAL_CW`` to
        "R"; anything else, or any failure while reading the tag, gives "N"
        (none/unknown).

        NOTE(review): this is a direct tag-to-letter mapping kept for
        backwards compatibility; presumably it is not a CIP assignment --
        confirm before relying on the R/S labels.
        """
        try:
            tag = atom.GetChiralTag()
            if tag == Chem.ChiralType.CHI_TETRAHEDRAL_CCW:
                return "S"
            if tag == Chem.ChiralType.CHI_TETRAHEDRAL_CW:
                return "R"
        except Exception:
            pass
        return "N"

    def _ring_sizes(self, atom: Chem.Atom | _AtomLike) -> List[int]:
        """
        Return list of ring sizes the atom belongs to (empty if none).

        Any failure while querying ring information yields an empty list.
        """
        try:
            idx = atom.GetIdx()  # looked up once instead of once per ring
            return [
                len(ring)
                for ring in self.mol.GetRingInfo().AtomRings()
                if idx in ring
            ]
        except Exception:
            return []

    @staticmethod
    def _neighbor_counts(atom: Chem.Atom | _AtomLike) -> Dict[str, int]:
        """
        Count neighbor element occurrences (e.g., {"H": 3, "C": 1}).

        If iterating the neighbors fails part-way, the counts gathered so
        far are returned.
        """
        counts: Dict[str, int] = {}
        try:
            for nb in atom.GetNeighbors():
                symbol = nb.GetSymbol()
                counts[symbol] = counts.get(symbol, 0) + 1
        except Exception:
            pass
        return counts

    def _is_carbonyl_atom(self, a: Chem.Atom | _AtomLike) -> bool:
        """
        Heuristic: carbon atom double-bonded to oxygen (C=O).

        True when ``a`` is a carbon with an oxygen neighbor whose connecting
        bond has ``GetBondTypeAsDouble() >= 2.0``. Any lookup failure is
        treated as "not a carbonyl".
        """
        try:
            if a.GetSymbol() != "C":
                return False
            for nb in a.GetNeighbors():
                if nb.GetSymbol() != "O":
                    continue
                bond = self.mol.GetBondBetweenAtoms(a.GetIdx(), nb.GetIdx())
                if bond is not None and bond.GetBondTypeAsDouble() >= 2.0:
                    return True
        except Exception:
            pass
        return False

    def _dist_to(
        self, predicate: Callable[[Chem.Atom], bool], start_idx: int, maxd: int = 99
    ) -> int:
        """
        Shortest-path distance (in bonds) from atom ``start_idx`` to the
        first atom that satisfies ``predicate``. Returns ``maxd`` if none
        found within the search limit.

        Implemented as a breadth-first search over ``GetNeighbors()``. A
        predicate that raises is treated as "no match" for that atom, and
        any other failure during the search also yields ``maxd``.

        :param predicate: callable that accepts an RDKit atom and returns bool.
        :param start_idx: starting atom index.
        :param maxd: maximum distance to search (defaults to 99).
        :returns: integer distance (0 means start atom satisfies predicate).
        """
        try:
            visited = {start_idx}
            queue = deque([(start_idx, 0)])
            while queue:
                idx, depth = queue.popleft()
                current = self.mol.GetAtomWithIdx(idx)
                try:
                    if predicate(current):
                        return depth
                except Exception:
                    # ignore predicate failures for robustness
                    pass
                if depth >= maxd:
                    # do not expand beyond the search limit
                    continue
                for nb in current.GetNeighbors():
                    nb_idx = nb.GetIdx()
                    if nb_idx not in visited:
                        visited.add(nb_idx)
                        queue.append((nb_idx, depth + 1))
        except Exception:
            pass
        return maxd

    # ---------------- utilities / metadata --------------------------------
    def __repr__(self) -> str:
        """Return ``ClassName(profile=..., n_atoms=...)``; ``n_atoms`` is -1 if unavailable."""
        try:
            n_atoms = self.mol.GetNumAtoms()
        except Exception:
            n_atoms = -1
        return f"{self.__class__.__name__}(profile={self.profile!r}, n_atoms={n_atoms})"

    @classmethod
    def help(cls) -> str:
        """
        Short machine-readable help describing supported profiles.
        """
        return (
            "AtomFeatureExtractor.help() -> str\n\n"
            "Supported profiles:\n"
            " - 'minimal': compact, standard atom properties\n"
            " - 'full' : includes valence, ring sizes, neighbor counts, "
            "distances to functional groups, and optional PerMolDescriptors fields\n"
        )