API Reference
ArchetypalAnalysis
- class archetypax.ArchetypalAnalysis(n_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, normalize: bool = False, **kwargs)
Bases: BaseEstimator, TransformerMixin
GPU-accelerated Archetypal Analysis implementation using JAX.
This class provides the core functionality for identifying archetypes: extreme points whose convex combinations can represent the data, which gives an interpretable view of the data's structure.
Leverages JAX for efficient GPU computation and automatic differentiation.
- __init__(n_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, normalize: bool = False, **kwargs)
Initialize the Archetypal Analysis model.
- Parameters:
n_archetypes – Number of archetypes to find - determines the dimensionality of the representation space
max_iter – Maximum number of iterations for optimization convergence
tol – Convergence tolerance for early stopping
random_seed – Random seed for reproducible results
learning_rate – Learning rate for optimizer - lower values provide better stability at the cost of slower convergence
normalize – Whether to normalize the data before fitting - essential for features with different scales
**kwargs – Additional keyword arguments, including:
- early_stopping_patience: Number of iterations without improvement before stopping optimization
- logger_level/verbose_level: Control for logging granularity
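How tol and early_stopping_patience interact is not spelled out in this reference. The sketch below shows one common convention, stated as an assumption: an iteration counts as an improvement only if the loss drops by more than tol, and optimization stops after patience consecutive iterations without one. The function name stopping_iteration is hypothetical and not part of archetypax.

```python
def stopping_iteration(losses, tol=1e-6, patience=3):
    """Return the index at which a patience-based rule would stop.

    Assumed rule (not confirmed by the archetypax reference): an iteration
    improves only if it lowers the best loss seen so far by more than `tol`.
    """
    best = float("inf")
    stale = 0  # consecutive iterations without improvement
    for i, loss in enumerate(losses):
        if best - loss > tol:
            best = loss
            stale = 0
        else:
            stale += 1
            if stale >= patience:
                return i
    # Never ran out of patience: stop at the last available iteration.
    return len(losses) - 1


losses = [1.0, 0.5, 0.4, 0.4, 0.4, 0.4, 0.3]
stop_at = stopping_iteration(losses, tol=1e-6, patience=3)
```

With this rule a larger patience tolerates longer plateaus, while a larger tol treats small decreases as no improvement.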
- fit(X: ndarray, normalize: bool = False, **kwargs) -> ArchetypalAnalysis
Fit the model to the data.
Identifies archetypes and weights through iterative optimization. Uses the Adam optimizer, with a projection step after each update so that the constraints stay satisfied.
- Parameters:
X – Data matrix (n_samples, n_features)
normalize – Whether to normalize the data before fitting
**kwargs – Additional keyword arguments for fine-tuning the fitting process
- Returns:
Self - fitted model instance
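To illustrate the idea of an Adam update followed by a projection step, here is a minimal NumPy sketch. It is deliberately simplified and is not the library's code: the archetypes are held fixed at randomly sampled data points, only the weights are optimized, and the projection is a clip-and-renormalize step. The function name adam_projected_weights and its defaults are assumptions made for this example.

```python
import numpy as np


def adam_projected_weights(X, n_archetypes, max_iter=200, learning_rate=0.05, seed=0):
    """Optimize simplex-constrained weights with Adam plus a projection step.

    Simplified illustration: archetypes are fixed at sampled data points,
    which is an assumption of this sketch, not archetypax behaviour.
    """
    rng = np.random.default_rng(seed)
    n_samples = X.shape[0]
    archetypes = X[rng.choice(n_samples, n_archetypes, replace=False)].copy()

    # Start from random weights that already lie on the simplex.
    W = rng.random((n_samples, n_archetypes))
    W /= W.sum(axis=1, keepdims=True)

    m = np.zeros_like(W)  # first-moment estimate
    v = np.zeros_like(W)  # second-moment estimate
    beta1, beta2, eps = 0.9, 0.999, 1e-8
    history = []

    for t in range(1, max_iter + 1):
        residual = W @ archetypes - X
        history.append(float(np.mean(np.sum(residual**2, axis=1))))

        # Gradient of the mean squared reconstruction error w.r.t. W.
        grad = 2.0 * residual @ archetypes.T / n_samples

        # Standard Adam update with bias correction.
        m = beta1 * m + (1 - beta1) * grad
        v = beta2 * v + (1 - beta2) * grad**2
        m_hat = m / (1 - beta1**t)
        v_hat = v / (1 - beta2**t)
        W = W - learning_rate * m_hat / (np.sqrt(v_hat) + eps)

        # Projection step: clip to positive values, then renormalize rows.
        W = np.maximum(W, 1e-10)
        W /= W.sum(axis=1, keepdims=True)

    return archetypes, W, history


X_demo = np.random.default_rng(1).normal(size=(20, 2))
A_demo, W_demo, history_demo = adam_projected_weights(X_demo, n_archetypes=3)
```

The recorded history list plays the same role as the values returned by get_loss_history().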
- fit_transform(X: ndarray, y: ndarray | None = None, normalize: bool = False) -> ndarray
Fit the model and transform the data.
- Parameters:
X – Data matrix of shape (n_samples, n_features)
y – Ignored. Present for API consistency by convention.
normalize – Whether to normalize the data before fitting.
- Returns:
Weight matrix representing each sample as a combination of archetypes
- get_loss_history() -> list[float]
Get the loss history from training.
- Returns:
List of loss values recorded during fitting
- loss_function(archetypes: Array, weights: Array, X: Array) -> Array
Calculate reconstruction loss with entropy regularization.
Computes the fundamental objective: minimize reconstruction error while encouraging more discriminative weights through entropy regularization.
- Parameters:
archetypes – Archetype matrix (n_archetypes, n_features)
weights – Weight matrix (n_samples, n_archetypes)
X – Data matrix (n_samples, n_features)
- Returns:
Combined loss value as a scalar
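The exact form of the objective is not given here, so the following is a sketch under stated assumptions: the reconstruction term is the mean squared error between X and weights @ archetypes, and the regularizer adds the mean row entropy of the weights scaled by a coefficient lam, so that lower-entropy (more peaked) weights are favoured. The name reconstruction_loss_with_entropy, the coefficient lam, and the sign convention are assumptions, not the library's definitions.

```python
import numpy as np


def reconstruction_loss_with_entropy(archetypes, weights, X, lam=0.01, eps=1e-10):
    """Mean squared reconstruction error plus an entropy penalty on the weights.

    `lam` and the sign of the penalty are assumptions made for this sketch.
    """
    X_hat = weights @ archetypes
    reconstruction = np.mean(np.sum((X - X_hat) ** 2, axis=1))
    # Row-wise Shannon entropy; eps guards against log(0).
    entropy = -np.sum(weights * np.log(weights + eps), axis=1)
    return reconstruction + lam * np.mean(entropy)


archetypes = np.array([[0.0, 0.0], [1.0, 1.0]])
X = archetypes.copy()

sharp = np.eye(2)               # each sample assigned to one archetype
uniform = np.full((2, 2), 0.5)  # each sample spread evenly

loss_sharp = reconstruction_loss_with_entropy(archetypes, sharp, X)
loss_uniform = reconstruction_loss_with_entropy(archetypes, uniform, X)
```

In this toy case the one-hot weights reconstruct the data exactly and have near-zero entropy, so their loss is lower than that of the uniform weights.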
- project_archetypes(archetypes: Array, X: Array) -> Array
Project archetypes using soft assignment based on k-nearest neighbors.
Ensures archetypes remain within the convex hull of data points by creating soft assignments based on proximity. This approach offers better stability than hard assignment methods.
- Parameters:
archetypes – Archetype matrix (n_archetypes, n_features)
X – Original data matrix (n_samples, n_features)
- Returns:
Projected archetype matrix (n_archetypes, n_features)
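One way to realize a soft k-nearest-neighbour projection is sketched below, as an assumption rather than a description of the library's internals: each archetype is replaced by a softmax-weighted average of its k nearest data points, which is a convex combination of data points and therefore lies inside their convex hull. The function name soft_knn_project and the parameters k and temperature are hypothetical.

```python
import numpy as np


def soft_knn_project(archetypes, X, k=3, temperature=1.0):
    """Replace each archetype by a soft average of its k nearest data points."""
    # Squared Euclidean distances, shape (n_archetypes, n_samples).
    d2 = np.sum((archetypes[:, None, :] - X[None, :, :]) ** 2, axis=2)

    # Indices and distances of the k nearest samples for each archetype.
    idx = np.argsort(d2, axis=1)[:, :k]
    nearest = np.take_along_axis(d2, idx, axis=1)

    # Softmax over negative distances; subtracting the max avoids overflow.
    logits = -nearest / temperature
    logits -= logits.max(axis=1, keepdims=True)
    w = np.exp(logits)
    w /= w.sum(axis=1, keepdims=True)

    # Convex combination of the selected neighbours, shape (n_archetypes, n_features).
    return np.einsum("ak,akf->af", w, X[idx])


X_square = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
outside = np.array([[5.0, 5.0], [-3.0, 0.5]])
projected = soft_knn_project(outside, X_square, k=3)
```

Both starting points lie outside the unit square, while the projected points fall inside it.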
- project_weights(weights: Array) -> Array
Project weights to satisfy simplex constraints with numerical stability.
Ensures that weights form valid convex combinations (non-negative and sum to 1) while avoiding numerical underflow/overflow issues.
- Parameters:
weights – Weight matrix (n_samples, n_archetypes)
- Returns:
Projected weight matrix (n_samples, n_archetypes)
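A simple way to obtain non-negative rows that sum to 1 is to clip and renormalize, as sketched below. This is not the exact Euclidean projection onto the simplex, and whether archetypax uses this scheme is an assumption; the name clip_and_normalize and the floor eps are hypothetical.

```python
import numpy as np


def clip_and_normalize(weights, eps=1e-10):
    """Map each row to the simplex by clipping at eps and renormalizing.

    The positive floor keeps the row sum away from zero, so the division
    stays well defined even for rows that are entirely non-positive.
    """
    clipped = np.maximum(weights, eps)
    return clipped / np.sum(clipped, axis=1, keepdims=True)


raw = np.array([[2.0, -1.0, 1.0], [0.0, 0.0, 0.0]])
on_simplex = clip_and_normalize(raw)
```

The all-zero row becomes a uniform distribution over the archetypes instead of producing a division by zero.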
- reconstruct(X: ndarray | None = None) -> ndarray
Reconstruct data using the learned archetypes.
- Parameters:
X – Data to reconstruct, or None to use the training data
- Returns:
Reconstructed data of shape (n_samples, n_features)
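Given the shapes documented above, weights of shape (n_samples, n_archetypes) and archetypes of shape (n_archetypes, n_features), a reconstruction as convex combinations is a single matrix product. The small NumPy example below illustrates that arithmetic with made-up values; the helper reconstruct_from_weights is hypothetical and does not call archetypax.

```python
import numpy as np


def reconstruct_from_weights(weights, archetypes):
    """Each reconstructed row is a convex combination of the archetype rows."""
    return weights @ archetypes


archetypes = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])  # (3, 2)
weights = np.array(
    [
        [1.0, 0.0, 0.0],  # exactly the first archetype
        [0.5, 0.5, 0.0],  # midpoint of the first two archetypes
        [0.2, 0.3, 0.5],  # a mixture of all three
    ]
)  # (3, 3), each row sums to 1

X_hat = reconstruct_from_weights(weights, archetypes)  # shape (3, 2)
```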
- set_fit_request(*, normalize: bool | None | str = '$UNCHANGED$') -> ArchetypalAnalysis
Request metadata passed to the fit method.
Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see the scikit-learn User Guide on how the routing mechanism works.
The options for each parameter are:
- True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
- False: metadata is not requested and the meta-estimator will not pass it to fit.
- None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
- str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.
- transform(X: ndarray, y: ndarray | None = None, **kwargs) -> ndarray
Transform new data to archetype weights.
- Parameters:
X – New data matrix of shape (n_samples, n_features)
y – Ignored. Present for API consistency by convention.
**kwargs – Additional keyword arguments for the transform method.
- Returns:
Weight matrix representing each sample as a combination of archetypes