Source code for synkit.Vis.embedding
from typing import Any, Dict, Optional
import numpy as np
from sklearn.manifold import TSNE
from joblib import Memory
[docs]
class Embedding:
def __init__(
self,
cache_dir: str = "./cachedir",
verbose: int = 0,
custom_tsne_params: Optional[Dict] = None,
) -> None:
"""Initialize the Embedding class with options for caching directory,
verbosity, and custom t-SNE parameters.
Parameters:
cache_dir (str): Directory where cached results are stored.
verbose (int): Verbosity level for the memory object.
custom_tsne_params (Dict, optional): Custom default parameters for t-SNE computations.
"""
self.memory = Memory(cache_dir, verbose=verbose)
self.default_tsne_params = {
"n_components": 2,
"perplexity": 30,
"learning_rate": 200,
"max_iter": 1000,
"random_state": 42,
}
if custom_tsne_params:
self.default_tsne_params.update(custom_tsne_params)
self.tsne_params = self.default_tsne_params.copy()
[docs]
def set_tsne_params(self, **params) -> None:
"""Sets parameters for t-SNE computations.
Parameters:
**params: Arbitrary number of parameters for t-SNE.
"""
self.tsne_params.update(params)
[docs]
def reset_tsne_params(self) -> None:
"""Resets t-SNE parameters to default values."""
self.tsne_params = self.default_tsne_params.copy()
def _compute_tsne(self, X: np.ndarray) -> np.ndarray:
"""Direct computation of the t-SNE embedding with the current
parameters.
Parameters:
X (np.ndarray): High-dimensional data points.
Returns:
np.ndarray: The 2-dimensional t-SNE embedding of the data.
"""
tsne = TSNE(**self.tsne_params)
return tsne.fit_transform(X)
[docs]
def compute_tsne(self, X: np.ndarray, cache: bool = True) -> np.ndarray:
"""Computes or retrieves the t-SNE embedding from cache.
Parameters:
X (np.ndarray): High-dimensional data points.
cache (bool): Determines whether to use caching for the computation.
Returns:
np.ndarray: The 2-dimensional t-SNE embedding of the data.
"""
if cache:
return self.cache(X)
else:
return self._compute_tsne(X)
@property
def cache(self) -> Any:
"""Decorator for caching the compute_tsne function.
Returns:
Callable: Cached function.
"""
return self.memory.cache(self._compute_tsne)
[docs]
def clear_cache(self) -> None:
"""Clears the cache directory."""
self.memory.clear()