API Reference

ArchetypalAnalysis

class archetypax.ArchetypalAnalysis(n_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, normalize: bool = False, **kwargs)[source]

Bases: BaseEstimator, TransformerMixin

GPU-accelerated Archetypal Analysis implementation using JAX.

This class provides the core functionality for identifying archetypes - extreme points that can represent data through convex combinations, offering interpretable and meaningful insights into data structure.

Leverages JAX for efficient GPU computation and automatic differentiation.

__init__(n_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, normalize: bool = False, **kwargs)[source]

Initialize the Archetypal Analysis model.

Parameters:
  • n_archetypes – Number of archetypes to find - determines the dimensionality of the representation space

  • max_iter – Maximum number of optimization iterations

  • tol – Convergence tolerance for early stopping

  • random_seed – Random seed for reproducible results

  • learning_rate – Learning rate for optimizer - lower values provide better stability at the cost of slower convergence

  • normalize – Whether to normalize the data before fitting - essential for features with different scales

  • **kwargs – Additional keyword arguments, including:

    early_stopping_patience: Number of iterations without improvement before stopping optimization

    logger_level/verbose_level: Control for logging granularity
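How tol and early_stopping_patience might interact can be sketched in plain Python. This is an illustration only: the helper name run_with_patience and the exact stopping rule are assumptions, not the library's code.

```python
def run_with_patience(losses, tol=1e-6, patience=3):
    """Return the iteration index at which a patience-based loop would stop.

    `losses` stands in for the loss value computed at each iteration.
    (Hypothetical helper; the library's actual rule may differ.)
    """
    best = float("inf")
    stale = 0  # iterations since the last improvement larger than tol
    for i, loss in enumerate(losses):
        if best - loss > tol:
            best = loss
            stale = 0
        else:
            stale += 1
            if stale >= patience:
                return i  # stop early: no meaningful progress
    return len(losses) - 1  # ran through every iteration


stop_at = run_with_patience([1.0, 0.5, 0.4, 0.4, 0.4, 0.4, 0.1], patience=3)
```

In this sketch the loop stops at index 5, after three consecutive iterations without an improvement larger than tol, so the later drop to 0.1 is never reached. A larger patience trades extra iterations for robustness against plateaus.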

fit(X: ndarray, normalize: bool = False, **kwargs) ArchetypalAnalysis[source]

Fit the model to the data.

Identifies optimal archetypes and weights through iterative optimization. Uses Adam optimizer with projection steps to ensure constraints are satisfied.

Parameters:
  • X – Data matrix (n_samples, n_features)

  • normalize – Whether to normalize the data before fitting

  • **kwargs – Additional keyword arguments for fine-tuning the fitting process

Returns:

Self - fitted model instance
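The alternation of gradient steps and projection steps can be illustrated with a small NumPy sketch. It makes several simplifying assumptions that are not taken from the library: plain gradient descent stands in for Adam, only the weights are projected (archetypes are left unconstrained), the loss is the unregularized squared error, and the names sketch_fit and project_rows_to_simplex are hypothetical.

```python
import numpy as np


def project_rows_to_simplex(V):
    """Euclidean projection of each row of V onto the probability simplex."""
    n, k = V.shape
    U = np.sort(V, axis=1)[:, ::-1]            # each row sorted in descending order
    css = np.cumsum(U, axis=1) - 1.0
    cond = U - css / np.arange(1, k + 1) > 0   # True for a prefix of each row
    rho = np.count_nonzero(cond, axis=1)       # number of entries that stay positive
    theta = css[np.arange(n), rho - 1] / rho   # per-row shift
    return np.maximum(V - theta[:, None], 0.0)


def sketch_fit(X, n_archetypes, max_iter=500, tol=1e-6, learning_rate=0.01, seed=42):
    rng = np.random.default_rng(seed)
    n = X.shape[0]
    # Start archetypes at randomly chosen data points, weights on the simplex.
    A = X[rng.choice(n, size=n_archetypes, replace=False)].copy()
    W = project_rows_to_simplex(rng.random((n, n_archetypes)))
    history = []
    for _ in range(max_iter):
        R = W @ A - X                                   # residual (n_samples, n_features)
        history.append(float(np.sum(R ** 2) / n))       # mean squared reconstruction error
        if len(history) > 1 and abs(history[-2] - history[-1]) < tol:
            break                                       # converged within tolerance
        grad_W = 2.0 * R @ A.T / n
        grad_A = 2.0 * W.T @ R / n
        W = project_rows_to_simplex(W - learning_rate * grad_W)  # step, then project
        A = A - learning_rate * grad_A                  # unconstrained in this sketch
    return A, W, history


# Toy data: convex combinations of the corners of a triangle.
rng = np.random.default_rng(0)
triangle = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
X_demo = rng.dirichlet(np.ones(3), size=60) @ triangle
A_fit, W_fit, history = sketch_fit(X_demo, n_archetypes=3, max_iter=200, tol=0.0)
```

After every iteration each row of W_fit is non-negative and sums to 1, which is the constraint the projection step exists to maintain.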

fit_transform(X: ndarray, y: ndarray | None = None, normalize: bool = False) ndarray[source]

Fit the model and transform the data.

Parameters:
  • X – Data matrix of shape (n_samples, n_features)

  • y – Ignored. Present for API consistency by convention.

  • normalize – Whether to normalize the data before fitting.

Returns:

Weight matrix representing each sample as a combination of archetypes

get_loss_history() list[float][source]

Get the loss history from training.

Returns:

List of loss values recorded during fitting

loss_function(archetypes: Array, weights: Array, X: Array) Array[source]

Calculate reconstruction loss with entropy regularization.

Computes the fundamental objective: minimize reconstruction error while encouraging more discriminative weights through entropy regularization.

Parameters:
  • archetypes – Archetype matrix (n_archetypes, n_features)

  • weights – Weight matrix (n_samples, n_archetypes)

  • X – Data matrix (n_samples, n_features)

Returns:

Combined loss value as a scalar
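A loss of this kind can be sketched in NumPy as follows. The exact formula used by the library is not given here, so treat every detail as an assumption: the entropy is added as a penalty, and the regularization strength lam, the stabilizer eps, and the name sketch_loss are all hypothetical.

```python
import numpy as np


def sketch_loss(archetypes, weights, X, lam=0.01, eps=1e-10):
    # Each sample is approximated by a convex combination of archetypes.
    X_hat = weights @ archetypes
    reconstruction = np.mean(np.sum((X - X_hat) ** 2, axis=1))
    # Shannon entropy per sample: low when the weights are peaked on few archetypes.
    entropy = -np.sum(weights * np.log(weights + eps), axis=1)
    # Assumed sign: penalizing entropy pushes weights towards peaked rows.
    return reconstruction + lam * np.mean(entropy)


A = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])   # (n_archetypes, n_features)
W = np.array([[1.0, 0.0, 0.0], [0.0, 0.5, 0.5]])     # (n_samples, n_archetypes)
X = W @ A                                            # data that reconstructs exactly
loss = sketch_loss(A, W, X)
```

Because X is reconstructed exactly in this example, the remaining loss comes only from the entropy term of the second sample, whose weight is split evenly between two archetypes.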

project_archetypes(archetypes: Array, X: Array) Array[source]

Project archetypes using soft assignment based on k-nearest neighbors.

Ensures archetypes remain within the convex hull of data points by creating soft assignments based on proximity. This approach offers better stability than hard assignment methods.

Parameters:
  • archetypes – Archetype matrix (n_archetypes, n_features)

  • X – Original data matrix (n_samples, n_features)

Returns:

Projected archetype matrix (n_archetypes, n_features)
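One way to realize a soft k-nearest-neighbor projection is sketched below. It is an assumption about the general idea, not the library's implementation: the neighbor count k, the softmax over negative distances, and the name sketch_project_archetypes are all hypothetical.

```python
import numpy as np


def sketch_project_archetypes(archetypes, X, k=3):
    # Pairwise Euclidean distances, shape (n_archetypes, n_samples).
    dists = np.linalg.norm(archetypes[:, None, :] - X[None, :, :], axis=2)
    # Indices of the k nearest data points for each archetype.
    nearest = np.argsort(dists, axis=1)[:, :k]
    near_d = np.take_along_axis(dists, nearest, axis=1)
    # Softmax over negative distances: closer points receive larger weights.
    logits = -near_d
    logits = logits - logits.max(axis=1, keepdims=True)  # stabilize the exponentials
    w = np.exp(logits)
    w = w / w.sum(axis=1, keepdims=True)
    # A convex combination of data points lies inside their convex hull.
    return np.einsum("ak,akf->af", w, X[nearest])


X = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
outside = np.array([[2.0, 2.0], [-1.0, -1.0]])   # archetypes that drifted outside the data
projected = sketch_project_archetypes(outside, X, k=3)
```

Each projected archetype is a weighted average of actual data points, so it cannot leave the convex hull of X. A hard assignment would correspond to k=1, where each archetype snaps to its single nearest point.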

project_weights(weights: Array) Array[source]

Project weights to satisfy simplex constraints with numerical stability.

Ensures that weights form valid convex combinations (non-negative and sum to 1) while avoiding numerical underflow/overflow issues.

Parameters:

weights – Weight matrix (n_samples, n_archetypes)

Returns:

Projected weight matrix (n_samples, n_archetypes)
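A simple way to enforce these constraints is to clip and renormalize each row, as in the sketch below. This is an illustrative assumption, not the library's code: the floor eps and the name sketch_project_weights are hypothetical, and clip-and-renormalize is a cheaper substitute for the exact Euclidean projection onto the simplex.

```python
import numpy as np


def sketch_project_weights(weights, eps=1e-10):
    # Clip to a small positive floor so no row can sum to zero.
    w = np.maximum(weights, eps)
    # Rescale each row so that it sums to 1.
    return w / w.sum(axis=1, keepdims=True)


raw = np.array([[0.2, -0.1, 0.5], [3.0, 1.0, 0.0], [-1.0, -2.0, -3.0]])
projected = sketch_project_weights(raw)
```

The last row shows why the floor matters: all of its entries are negative, so without eps the row sum would be zero and the division would fail. With the floor, the row becomes a uniform weight vector.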

reconstruct(X: ndarray | None = None) ndarray[source]

Reconstruct data using the learned archetypes.

Parameters:

X – Data to reconstruct, or None to use the training data

Returns:

Reconstructed data of shape (n_samples, n_features)
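Conceptually, reconstruction multiplies the weight matrix by the archetype matrix. The NumPy lines below show the shape bookkeeping with made-up values; they do not call the library, and any normalization the library may undo is not shown.

```python
import numpy as np

archetypes = np.array([[0.0, 0.0], [4.0, 0.0], [0.0, 4.0]])  # (n_archetypes, n_features)
weights = np.array([[0.5, 0.5, 0.0], [0.25, 0.25, 0.5]])     # (n_samples, n_archetypes)

# (n_samples, n_archetypes) @ (n_archetypes, n_features) -> (n_samples, n_features)
X_hat = weights @ archetypes
# X_hat is [[2.0, 0.0], [1.0, 2.0]]
```

The first sample sits halfway between the first two archetypes, and the second is pulled towards the third archetype by its weight of 0.5.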

set_fit_request(*, normalize: bool | None | str = '$UNCHANGED$') ArchetypalAnalysis

Request metadata passed to the fit method.

Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to fit.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version 1.3.

Note

This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.

Parameters:

normalize (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for normalize parameter in fit.

Returns:

self – The updated object.

Return type:

object

transform(X: ndarray, y: ndarray | None = None, **kwargs) ndarray[source]

Transform new data to archetype weights.

Parameters:
  • X – New data matrix of shape (n_samples, n_features)

  • y – Ignored. Present for API consistency by convention.

  • **kwargs – Additional keyword arguments for the transform method.

Returns:

Weight matrix representing each sample as a combination of archetypes