API Reference

ArchetypalAnalysis

class archetypax.ArchetypalAnalysis(n_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, normalize: bool = False, **kwargs)

Bases: BaseEstimator, TransformerMixin

GPU-accelerated Archetypal Analysis implementation using JAX.

__init__(n_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, normalize: bool = False, **kwargs)

Initialize the Archetypal Analysis model.

Parameters:
  • n_archetypes – Number of archetypes to find

  • max_iter – Maximum number of iterations

  • tol – Convergence tolerance

  • random_seed – Random seed for initialization

  • learning_rate – Learning rate for the optimizer. The default is kept small for stability.

  • normalize – Whether to normalize the data before fitting.

  • **kwargs – Additional keyword arguments for the fit method.
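
This reference does not say what normalization the normalize flag applies. The following is a minimal sketch of one common choice, per-feature standardization, assuming that the option does something along these lines. The helper standardize and its eps argument are hypothetical and not part of archetypax.

```python
import numpy as np

def standardize(X, eps=1e-10):
    """Hypothetical helper: zero-mean, unit-variance scaling per feature."""
    mean = X.mean(axis=0)
    std = X.std(axis=0)
    # eps guards against division by zero for constant features
    return (X - mean) / (std + eps), mean, std

X = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
X_scaled, mean, std = standardize(X)
```

Keeping mean and std around makes it possible to map archetypes found in the scaled space back to the original units.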

fit(X: ndarray, normalize: bool = False, **kwargs) -> ArchetypalAnalysis

Fit the model to the data.

Parameters:
  • X – Data matrix of shape (n_samples, n_features)

  • normalize – Whether to normalize the data before fitting.

  • **kwargs – Additional keyword arguments.

Returns:

The fitted estimator (self)

fit_transform(X: ndarray, y: ndarray | None = None, normalize: bool = False) -> ndarray

Fit the model and transform the data.

Parameters:
  • X – Data matrix of shape (n_samples, n_features)

  • y – Ignored. Present for API consistency by convention.

  • normalize – Whether to normalize the data before fitting.

Returns:

Weight matrix representing each sample as a combination of archetypes

get_loss_history() -> list[float]

Get the loss history from training.

Returns:

List of loss values recorded during fitting
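
A loss history like this is typically inspected to see when optimization levelled off. The sketch below assumes a stopping rule based on the absolute change between consecutive loss values compared with tol. The reference does not state the rule the library actually uses, and converged_at is a hypothetical helper.

```python
def converged_at(loss_history, tol=1e-6):
    """Return the index of the first iteration whose loss changed by less than tol, or None."""
    for i in range(1, len(loss_history)):
        if abs(loss_history[i - 1] - loss_history[i]) < tol:
            return i
    return None

# Illustrative values only, not output from a real fit
history = [1.0, 0.5, 0.25, 0.2499999, 0.2499999]
step = converged_at(history)
```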

loss_function(archetypes, weights, X)

Compute the loss function with an added regularization term.
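
The form of the regularization term is not given in this reference. As a rough illustration only, the sketch below combines a mean squared reconstruction error with an entropy penalty on the weights. The entropy penalty, the strength lambda_reg, and the function name regularized_loss are all assumptions rather than the library's definition. The argument order follows the signature above.

```python
import numpy as np

def regularized_loss(archetypes, weights, X, lambda_reg=0.01, eps=1e-10):
    """Hypothetical loss: mean squared reconstruction error plus an entropy penalty."""
    X_hat = weights @ archetypes                      # (n_samples, n_features)
    reconstruction = np.mean(np.sum((X - X_hat) ** 2, axis=1))
    # Entropy of each weight row; low entropy means a sample leans on few archetypes
    entropy = -np.sum(weights * np.log(weights + eps), axis=1)
    return reconstruction + lambda_reg * np.mean(entropy)

archetypes = np.array([[0.0, 0.0], [1.0, 1.0]])
weights = np.array([[1.0, 0.0], [0.5, 0.5]])
X = weights @ archetypes          # data that the weights reconstruct exactly
loss = regularized_loss(archetypes, weights, X)
```

With exactly reconstructable data the first term vanishes, so the value of loss here comes from the penalty alone.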

project_archetypes(archetypes, X)

Project archetypes using soft assignment based on k-nearest neighbors.

Parameters:
  • archetypes – Archetype matrix

  • X – Original data matrix

Returns:

Projected archetype matrix
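
One way to read "soft assignment based on k-nearest neighbors" is sketched below: each archetype is replaced by a softmax-weighted average of its k nearest samples, which keeps it inside the convex hull of the data. This is an interpretation, not the library's code. The function soft_knn_project and the parameters k and temperature are hypothetical.

```python
import numpy as np

def soft_knn_project(archetypes, X, k=3, temperature=1.0):
    """Hypothetical sketch: softmax-weighted mean of each archetype's k nearest samples."""
    # Pairwise Euclidean distances, shape (n_archetypes, n_samples)
    dists = np.linalg.norm(archetypes[:, None, :] - X[None, :, :], axis=2)
    nearest = np.argsort(dists, axis=1)[:, :k]              # indices of the k nearest samples
    near_d = np.take_along_axis(dists, nearest, axis=1)     # their distances
    # Softmax over negative distances; subtracting the row minimum keeps exp() stable
    logits = -(near_d - near_d.min(axis=1, keepdims=True)) / temperature
    w = np.exp(logits)
    w /= w.sum(axis=1, keepdims=True)
    # Each projected archetype is a convex combination of real data points
    return np.einsum("ak,akf->af", w, X[nearest])

X = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
archetypes = np.array([[-5.0, -5.0], [5.0, 5.0]])   # deliberately outside the data
projected = soft_knn_project(archetypes, X, k=2)
```

A smaller temperature pushes each archetype toward its single nearest sample, while a larger one spreads it more evenly over the k neighbors.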

project_weights(weights)

Project weights to satisfy simplex constraints with numerical stability.

Parameters:

weights – Weight matrix

Returns:

Projected weight matrix
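
The simplex constraints require every weight to be non-negative and every row to sum to 1. A minimal sketch of one simple way to enforce this is clipping followed by row normalization, with a small floor for numerical stability. This is not the Euclidean projection onto the simplex, and the reference does not say which scheme the library uses. The name project_to_simplex and the eps value are assumptions.

```python
import numpy as np

def project_to_simplex(weights, eps=1e-10):
    """Hypothetical sketch: floor entries at eps, then rescale each row to sum to 1."""
    clipped = np.maximum(weights, eps)     # the eps floor avoids dividing an all-zero row by zero
    return clipped / clipped.sum(axis=1, keepdims=True)

raw = np.array([[2.0, -1.0, 2.0], [0.0, 0.0, 0.0]])
projected = project_to_simplex(raw)
```

The second row shows why the floor matters: without it, an all-zero row would divide by zero instead of falling back to uniform weights.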

reconstruct(X: ndarray = None) -> ndarray

Reconstruct data using the learned archetypes.

Parameters:

X – Data to reconstruct, or None to use the training data

Returns:

Reconstructed data
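
In the standard archetypal analysis formulation, a reconstruction is the product of the weight matrix and the archetype matrix. The sketch below illustrates that arithmetic with made-up values. It assumes weights of shape (n_samples, n_archetypes) and archetypes of shape (n_archetypes, n_features), and it does not call the library.

```python
import numpy as np

# Three archetypes in a 2-D feature space (illustrative values)
archetypes = np.array([[0.0, 0.0], [4.0, 0.0], [0.0, 4.0]])
# One row per sample; each row is non-negative and sums to 1
weights = np.array([
    [1.0, 0.0, 0.0],
    [0.5, 0.5, 0.0],
    [0.25, 0.25, 0.5],
])
# Each reconstructed sample is a convex combination of the archetypes
X_hat = weights @ archetypes
```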

set_fit_request(*, normalize: bool | None | str = '$UNCHANGED$') -> ArchetypalAnalysis

Request metadata passed to the fit method.

Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see the scikit-learn User Guide for how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to fit.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in scikit-learn version 1.3.

Note

This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.

Parameters:

normalize (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for normalize parameter in fit.

Returns:

self – The updated object.

Return type:

object

transform(X: ndarray, y: ndarray | None = None, **kwargs) -> ndarray

Transform new data to archetype weights.

Parameters:
  • X – New data matrix of shape (n_samples, n_features)

  • y – Ignored. Present for API consistency by convention.

  • **kwargs – Additional keyword arguments for the transform method.

Returns:

Weight matrix representing each sample as a combination of archetypes
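
A weight matrix of this kind is often summarized by the archetype that contributes most to each sample. The sketch below uses illustrative values and assumes the matrix has shape (n_samples, n_archetypes), which this reference does not state explicitly.

```python
import numpy as np

# Illustrative weights, one row per sample and one column per archetype
weights = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.1, 0.8],
    [0.4, 0.5, 0.1],
])
# Index of the archetype with the largest weight for each sample
dominant = weights.argmax(axis=1)
# The largest weight itself, a rough measure of how "pure" each sample is
purity = weights.max(axis=1)
```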