API Reference
ArchetypalAnalysis
- class archetypax.ArchetypalAnalysis(n_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, normalize: bool = False, **kwargs)
Bases: BaseEstimator, TransformerMixin
GPU-accelerated Archetypal Analysis implementation using JAX.
- __init__(n_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, normalize: bool = False, **kwargs)
Initialize the Archetypal Analysis model.
- Parameters:
n_archetypes – Number of archetypes to find
max_iter – Maximum number of iterations
tol – Convergence tolerance
random_seed – Random seed for initialization
learning_rate – Learning rate for the optimizer. The small default (0.001) favors numerical stability.
normalize – Whether to normalize the data before fitting.
**kwargs – Additional keyword arguments for the fit method.
- fit(X: ndarray, normalize: bool = False, **kwargs) → ArchetypalAnalysis
Fit the model to the data.
- Parameters:
X – Data matrix of shape (n_samples, n_features)
normalize – Whether to normalize the data before fitting.
**kwargs – Additional keyword arguments for the fit method.
- Returns:
Self
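The reference does not describe the optimization performed by fit. As a rough mental model only, the NumPy sketch below assumes the standard Cutler and Breiman formulation X ≈ W Z with archetypes Z = B X, where every row of W and B lies on the probability simplex, and minimizes the mean squared reconstruction error with alternating projected gradient steps. It is not the library's JAX implementation: it ignores learning_rate and normalize and uses step sizes of 1/L from Lipschitz bounds, which keeps the loss non-increasing. The names project_rows_to_simplex and fit_archetypes are hypothetical.

```python
import numpy as np


def project_rows_to_simplex(V):
    """Euclidean projection of each row of V onto the probability simplex."""
    n, k = V.shape
    U = -np.sort(-V, axis=1)                        # each row sorted in descending order
    css = np.cumsum(U, axis=1) - 1.0
    cond = U - css / np.arange(1, k + 1) > 0
    rho = k - np.argmax(cond[:, ::-1], axis=1)      # number of positive entries per row
    theta = css[np.arange(n), rho - 1] / rho        # per-row shift
    return np.maximum(V - theta[:, None], 0.0)


def fit_archetypes(X, n_archetypes, max_iter=500, tol=1e-6, random_seed=42):
    """Hypothetical sketch: minimize ||X - W B X||_F^2 / n with simplex rows in W and B.

    Alternating projected gradient steps of size 1/L (L bounds each block's
    gradient Lipschitz constant). This is NOT the archetypax JAX optimizer.
    """
    rng = np.random.default_rng(random_seed)
    n = X.shape[0]
    W = project_rows_to_simplex(rng.random((n, n_archetypes)))   # (n_samples, k)
    B = project_rows_to_simplex(rng.random((n_archetypes, n)))   # (k, n_samples)
    loss_history = []
    for _ in range(max_iter):
        Z = B @ X                                   # current archetypes, (k, n_features)
        # W-step: gradient of ||W Z - X||^2 / n with respect to W
        R = W @ Z - X
        L_w = 2.0 * np.linalg.norm(Z @ Z.T, 2) / n + 1e-12
        W = project_rows_to_simplex(W - (2.0 * R @ Z.T / n) / L_w)
        # B-step: gradient of ||W B X - X||^2 / n with respect to B
        R = W @ Z - X
        L_b = 2.0 * np.linalg.norm(W.T @ W, 2) * np.linalg.norm(X @ X.T, 2) / n + 1e-12
        B = project_rows_to_simplex(B - (2.0 * W.T @ R @ X.T / n) / L_b)
        loss_history.append(float(np.sum((X - W @ B @ X) ** 2) / n))
        # Assumed stopping rule: absolute change in loss below tol
        if len(loss_history) > 1 and abs(loss_history[-2] - loss_history[-1]) < tol:
            break
    return W, B @ X, loss_history
```

For example, with X = np.random.default_rng(0).random((60, 2)), fit_archetypes(X, 3) returns weights of shape (60, 3), archetypes of shape (3, 2) and a loss history that plays the same role as the list returned by get_loss_history().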
- fit_transform(X: ndarray, y: ndarray | None = None, normalize: bool = False) → ndarray
Fit the model and transform the data.
- Parameters:
X – Data matrix of shape (n_samples, n_features)
y – Ignored. Present for API consistency by convention.
normalize – Whether to normalize the data before fitting.
- Returns:
Weight matrix representing each sample as a combination of archetypes
- get_loss_history() → list[float]
Get the loss history from training.
- Returns:
List of loss values recorded during fitting
- project_archetypes(archetypes, X)
Project archetypes using soft assignment based on k-nearest neighbors.
- Parameters:
archetypes – Archetype matrix
X – Original data matrix
- Returns:
Projected archetype matrix
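The docstring only states that the projection uses a soft assignment based on k-nearest neighbors. One plausible reading, sketched below in NumPy, replaces each archetype with a softmax-weighted average of its k nearest samples, which keeps every archetype inside the convex hull of the data. The neighbor count k, the softmax weighting, the temperature parameter and the function name are all assumptions, not archetypax's documented behavior.

```python
import numpy as np


def project_archetypes_knn(archetypes, X, k=10, temperature=1.0):
    """Hypothetical sketch: move each archetype to a soft average of its k nearest samples."""
    k = min(k, X.shape[0])
    # Squared Euclidean distances, shape (n_archetypes, n_samples)
    d2 = ((archetypes[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)
    nn_idx = np.argsort(d2, axis=1)[:, :k]                  # k nearest samples per archetype
    nn_d2 = np.take_along_axis(d2, nn_idx, axis=1)
    # Softmax over negative distances; subtracting the row minimum avoids underflow
    logits = -(nn_d2 - nn_d2.min(axis=1, keepdims=True)) / temperature
    w = np.exp(logits)
    w /= w.sum(axis=1, keepdims=True)
    # Each projected archetype is a convex combination of its neighbors
    return np.einsum("ak,akf->af", w, X[nn_idx])
```

With k=1 this reduces to snapping each archetype onto its nearest sample; a larger k or temperature spreads the weight over more neighbors.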
- project_weights(weights)
Project weights onto the probability simplex (non-negative entries, each row summing to 1) in a numerically stable way.
- Parameters:
weights – Weight matrix
- Returns:
Projected weight matrix
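The simplex constraint means that each row of the weight matrix is non-negative and sums to 1. A minimal NumPy sketch of the exact Euclidean projection onto that set, using the standard sort-based algorithm, follows. Whether project_weights uses this exact projection or a simpler clip-and-renormalize step is not stated, so treat both the method and the function name as assumptions.

```python
import numpy as np


def project_rows_to_simplex(V):
    """Hypothetical sketch: Euclidean projection of each row of V onto the simplex."""
    n, k = V.shape
    U = -np.sort(-V, axis=1)                        # each row sorted in descending order
    css = np.cumsum(U, axis=1) - 1.0                # cumulative sums minus the target sum 1
    cond = U - css / np.arange(1, k + 1) > 0
    rho = k - np.argmax(cond[:, ::-1], axis=1)      # number of entries that stay positive
    theta = css[np.arange(n), rho - 1] / rho        # shift subtracted from each row
    return np.maximum(V - theta[:, None], 0.0)
```

For example, the row [0.5, 0.5, 0.5] sums to 1.5, so the projection subtracts 1/6 from every entry and returns [1/3, 1/3, 1/3].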
- reconstruct(X: ndarray = None) → ndarray
Reconstruct data using the learned archetypes.
- Parameters:
X – Data to reconstruct, or None to use the training data
- Returns:
Reconstructed data
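Assuming the usual archetypal analysis model, a reconstruction is the matrix product of the sample weights (as returned by transform or fit_transform) with the archetype matrix, so every reconstructed sample lies in the convex hull of the archetypes. The sketch below uses hypothetical helper names and plain NumPy arrays rather than the estimator's internal attributes.

```python
import numpy as np


def reconstruct_from_weights(weights, archetypes):
    """Hypothetical helper: X_hat = W @ Z, shape (n_samples, n_features)."""
    return weights @ archetypes


def reconstruction_mse(X, weights, archetypes):
    """Mean squared error between the data and its archetypal reconstruction."""
    return float(np.mean((X - reconstruct_from_weights(weights, archetypes)) ** 2))


# Three archetypes at the corners of a triangle in 2-D
archetypes = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
weights = np.array([[1.0, 0.0, 0.0],
                    [0.5, 0.5, 0.0],
                    [0.2, 0.3, 0.5]])
X_hat = reconstruct_from_weights(weights, archetypes)
# X_hat == [[0.0, 0.0], [0.5, 0.0], [0.3, 0.5]]
```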
- set_fit_request(*, normalize: bool | None | str = '$UNCHANGED$') → ArchetypalAnalysis
Request metadata passed to the fit method.
Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to fit.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.
- transform(X: ndarray, y: ndarray | None = None, **kwargs) → ndarray
Transform new data to archetype weights.
- Parameters:
X – New data matrix of shape (n_samples, n_features)
y – Ignored. Present for API consistency by convention.
**kwargs – Additional keyword arguments for the transform method.
- Returns:
Weight matrix representing each sample as a combination of archetypes
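For new data, a natural reading of transform is that the learned archetypes stay fixed and each sample's weights solve a least-squares problem constrained to the simplex. The NumPy sketch below does this with projected gradient descent and a fixed 1/L step. The solver, the stopping rule and all function names are assumptions, and how the real method uses **kwargs is not documented here.

```python
import numpy as np


def project_rows_to_simplex(V):
    """Euclidean projection of each row of V onto the probability simplex."""
    n, k = V.shape
    U = -np.sort(-V, axis=1)
    css = np.cumsum(U, axis=1) - 1.0
    cond = U - css / np.arange(1, k + 1) > 0
    rho = k - np.argmax(cond[:, ::-1], axis=1)
    theta = css[np.arange(n), rho - 1] / rho
    return np.maximum(V - theta[:, None], 0.0)


def transform_weights(X, archetypes, max_iter=1000, tol=1e-10):
    """Hypothetical sketch: argmin_W ||X - W Z||_F^2 with simplex rows and Z = archetypes fixed."""
    n, k = X.shape[0], archetypes.shape[0]
    W = np.full((n, k), 1.0 / k)                    # start every sample at the barycenter
    # Step 1/L, where L = 2 * ||Z Z^T||_2 is the Lipschitz constant of each row's gradient
    step = 1.0 / (2.0 * np.linalg.norm(archetypes @ archetypes.T, 2) + 1e-12)
    for _ in range(max_iter):
        grad = 2.0 * (W @ archetypes - X) @ archetypes.T
        W_new = project_rows_to_simplex(W - step * grad)
        if np.max(np.abs(W_new - W)) < tol:         # weights stopped moving
            return W_new
        W = W_new
    return W
```

With the triangle archetypes [[0, 0], [1, 0], [0, 1]], the interior point [0.2, 0.3] is recovered exactly with weights [0.5, 0.2, 0.3], while the outside point [1, 1] receives the weights of its nearest point in the hull, [0, 0.5, 0.5].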