archetypax.models package
Submodules
archetypax.models.archetypes module
Improved Archetypal Analysis model using JAX.
- class archetypax.models.archetypes.ArchetypeTracker(*args, **kwargs)
Bases: ImprovedArchetypalAnalysis
A specialized subclass designed to monitor the movement of archetypes.
- __init__(*args, **kwargs)
Initialize the ArchetypeTracker with parameters identical to those of ImprovedArchetypalAnalysis.
- fit(X: ndarray, normalize: bool = False, **kwargs) → ArchetypeTracker
Train the model while documenting the positions of archetypes at each iteration.
- Parameters:
X – Data matrix of shape (n_samples, n_features)
normalize – Whether to normalize the data before fitting.
**kwargs – Additional keyword arguments for the fit method.
- Returns:
Self
- project_archetypes(archetypes, X)
Override parent class projection with adaptive version for tracking.
- project_archetypes_with_adaptive_strength(archetypes, X)
Modified projection function that adapts its strength based on the current iteration.
In early iterations, the projection is very gentle to prevent large movements. As training progresses, it gradually increases to the normal projection strength.
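The ramp-up described above can be illustrated with a minimal sketch. The linear schedule, the function names, and every default value below are assumptions made for illustration; the library's actual schedule may differ.

```python
def adaptive_strength(iteration, warmup_iters=100, min_strength=0.01, max_strength=0.1):
    """Ramp the projection strength linearly from min_strength to max_strength.

    All parameter names and default values here are hypothetical.
    """
    progress = min(max(iteration / warmup_iters, 0.0), 1.0)
    return min_strength + progress * (max_strength - min_strength)


def blend_toward(archetype, target, strength):
    """Move an archetype a fraction `strength` of the way toward `target`."""
    return [(1.0 - strength) * a + strength * t for a, t in zip(archetype, target)]


# Early iterations barely move the archetype; later ones use the full strength.
early = blend_toward([0.0, 0.0], [1.0, 1.0], adaptive_strength(0))
late = blend_toward([0.0, 0.0], [1.0, 1.0], adaptive_strength(500))
```

With this schedule the first projection moves the archetype only 1% of the way toward its target, while projections after the warm-up move it 10% of the way.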
- set_fit_request(*, normalize: bool | None | str = '$UNCHANGED$') → ArchetypeTracker
Request metadata passed to the fit method.
Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see the User Guide on how the routing mechanism works.
The options for each parameter are:
  - True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
  - False: metadata is not requested and the meta-estimator will not pass it to fit.
  - None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
  - str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.
- visualize_boundary_proximity(figsize=(10, 5))
Visualize how close archetypes stayed to the convex hull boundary.
- Returns:
matplotlib figure
- visualize_movement(feature_indices=None, figsize=(12, 8), interval=1)
Visualize how archetypes moved during optimization.
- Parameters:
feature_indices – Indices of features to use for 2D plot. If None, PCA is used.
figsize – Figure size.
interval – Plot every nth iteration to avoid overcrowding.
- Returns:
matplotlib figure
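The roles of feature_indices and interval can be sketched without any plotting. The helper below is hypothetical: it assumes the recorded history is an array of shape (n_iterations, n_archetypes, n_features) and uses an SVD-based PCA when no feature indices are given. The storage format actually used by ArchetypeTracker is not documented here.

```python
import numpy as np


def trajectory_2d(history, feature_indices=None, interval=1):
    """Reduce a recorded archetype history to 2D coordinates for plotting.

    history: assumed shape (n_iterations, n_archetypes, n_features).
    """
    history = np.asarray(history, dtype=float)[::interval]  # keep every nth iteration
    n_steps, n_archetypes, n_features = history.shape
    flat = history.reshape(-1, n_features)
    if feature_indices is not None:
        # Use the two requested features directly.
        coords = flat[:, list(feature_indices)]
    else:
        # PCA via SVD: project onto the two leading principal directions.
        centered = flat - flat.mean(axis=0)
        _, _, vt = np.linalg.svd(centered, full_matrices=False)
        coords = centered @ vt[:2].T
    return coords.reshape(n_steps, n_archetypes, 2)
```

Each slice `result[:, k, :]` is then the 2D path of archetype k, which could be passed to any plotting routine.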
- class archetypax.models.archetypes.ImprovedArchetypalAnalysis(n_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, lambda_reg: float = 0.01, normalize: bool = False, projection_method: str = 'cbap', projection_alpha: float = 0.1, archetype_init_method: str = 'directional', **kwargs)
Bases: ArchetypalAnalysis
Improved Archetypal Analysis model using JAX.
- __init__(n_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, lambda_reg: float = 0.01, normalize: bool = False, projection_method: str = 'cbap', projection_alpha: float = 0.1, archetype_init_method: str = 'directional', **kwargs)
Initialize the Improved Archetypal Analysis model.
- Parameters:
n_archetypes – Number of archetypes to find
max_iter – Maximum number of iterations
tol – Convergence tolerance
random_seed – Random seed for initialization
learning_rate – Learning rate for optimization
lambda_reg – Regularization parameter
normalize – Whether to normalize the data
projection_alpha – Weight given to the extreme point when it is blended with the current archetype during projection
projection_method – Method for projecting archetypes:
  - “cbap”: Use CBAP projection
  - “convex_hull”: Use convex hull vertices
  - “knn”: Use k-nearest neighbors
archetype_init_method – Method for initializing archetypes:
  - “directional” or “direction”: Use directions from a sphere
  - “qhull” or “convex_hull”: Use convex hull vertices
  - “kmeans” or “kmeans++”: Use k-means++ initialization
**kwargs – Additional keyword arguments:
  - early_stopping_patience: Number of iterations to wait before stopping if there is no improvement
  - verbose_level: Level of verbosity (0, 1, 2)
    - 0: No output
    - 1: Basic output
    - 2: Detailed output
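As an illustration of what a patience-based stopping rule such as early_stopping_patience usually means, here is a minimal sketch. The function name and the use of a tolerance to define "improvement" are assumptions; the library's own criterion is not documented on this page.

```python
def should_stop(loss_history, patience=10, tol=1e-6):
    """Return True when the last `patience` losses did not improve on the earlier best.

    Hypothetical helper: `tol` is the minimum decrease that counts as an improvement.
    """
    if len(loss_history) <= patience:
        return False  # not enough history to judge yet
    best_before = min(loss_history[:-patience])
    best_recent = min(loss_history[-patience:])
    return best_recent > best_before - tol


still_improving = [10.0 - i for i in range(20)]
plateaued = [1.0, 0.5, 0.4] + [0.4] * 10
```

Here `should_stop(still_improving)` is False because the loss keeps decreasing, while `should_stop(plateaued)` is True because ten consecutive iterations brought no improvement.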
- directional_init(X_jax, n_samples, n_features)
Generate directions using points that are evenly distributed on a sphere.
- fit(X: ndarray, normalize: bool = False, **kwargs) → ImprovedArchetypalAnalysis
Fit the model with improved k-means++ initialization.
- Parameters:
X – Data matrix of shape (n_samples, n_features)
normalize – Whether to normalize the data before fitting.
**kwargs – Additional keyword arguments for the fit method.
- Returns:
The fitted model.
- Return type:
self
- fit_transform(X: ndarray, y: ndarray | None = None, normalize: bool = False, **kwargs) → ndarray
Fit the model and return the transformed data.
- Parameters:
X – Data matrix of shape (n_samples, n_features)
y – Ignored. Present for API consistency by convention.
normalize – Whether to normalize the data before fitting.
**kwargs – Additional keyword arguments for the fit method.
- Returns:
Weight matrix representing each sample as a combination of archetypes
- kmeans_pp_init(X_jax, n_samples, n_features)
More efficient k-means++ style initialization using JAX.
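For readers unfamiliar with k-means++ seeding, the standard procedure can be sketched in NumPy as follows. This is a generic illustration, not the JAX implementation used by kmeans_pp_init; the function name and signature are hypothetical.

```python
import numpy as np


def kmeans_pp_seed(X, n_archetypes, seed=42):
    """Pick initial points with standard k-means++ seeding (NumPy sketch).

    Assumes the data contain at least `n_archetypes` distinct points.
    """
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=float)
    n_samples = X.shape[0]
    # First seed: chosen uniformly at random.
    chosen = [int(rng.integers(n_samples))]
    min_sq_dist = np.sum((X - X[chosen[0]]) ** 2, axis=1)
    for _ in range(1, n_archetypes):
        # Later seeds: sampled with probability proportional to the squared
        # distance to the nearest seed chosen so far.
        probs = min_sq_dist / min_sq_dist.sum()
        idx = int(rng.choice(n_samples, p=probs))
        chosen.append(idx)
        min_sq_dist = np.minimum(min_sq_dist, np.sum((X - X[idx]) ** 2, axis=1))
    return X[chosen]
```

Because already-chosen points have zero distance to the seed set, they are never drawn again, which spreads the initial archetypes across the data.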
- loss_function(archetypes, weights, X)
Customized loss function with reduced boundary incentive for ArchetypeTracker.
This overrides the parent class loss function to provide more stability during tracking. Specifically, it reduces the boundary incentive weight to prevent archetypes from moving too rapidly in early iterations.
- project_archetypes(archetypes, X) → Array
JIT-compiled archetype projection that pushes archetypes towards the convex hull boundary.
Instead of using k-NN which tends to pull archetypes inside the convex hull, this implementation pushes archetypes towards the boundary of the convex hull by finding extreme points in the direction of each archetype.
Technical details:
  - Boundary Projection Approach: Projects data points along the direction from the data centroid to each archetype, then identifies the most extreme point in that direction. This effectively “pushes” archetypes toward the convex hull boundary rather than pulling them inward.
  - Stability Enhancement: Blends the original archetype with the extreme point using a weighted average (20% extreme point, 80% original archetype) to prevent abrupt changes and ensure optimization stability.
- Parameters:
archetypes – Current archetype matrix
X – Original data matrix
- Returns:
Projected archetype matrix positioned closer to the convex hull boundary
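The two steps listed under the technical details can be sketched in NumPy. This is an illustrative reading of the description, not the JIT-compiled JAX code; the function name and the `alpha` argument (set to the 20% blend quoted above) are assumptions.

```python
import numpy as np


def project_toward_boundary(archetypes, X, alpha=0.2):
    """Blend each archetype with the most extreme data point in its direction."""
    archetypes = np.asarray(archetypes, dtype=float)
    X = np.asarray(X, dtype=float)
    centroid = X.mean(axis=0)
    # Unit directions from the data centroid to each archetype.
    directions = archetypes - centroid
    norms = np.linalg.norm(directions, axis=1, keepdims=True)
    directions = directions / np.maximum(norms, 1e-10)
    # Project every centered data point onto every direction.
    scores = (X - centroid) @ directions.T  # (n_samples, n_archetypes)
    # The point with the largest projection is the extreme point.
    extreme = X[np.argmax(scores, axis=0)]
    # Weighted average: alpha * extreme point + (1 - alpha) * original archetype.
    return alpha * extreme + (1.0 - alpha) * archetypes


square = np.array([[0.0, 0.0], [2.0, 0.0], [0.0, 2.0], [2.0, 2.0], [1.0, 1.0]])
moved = project_toward_boundary(np.array([[1.5, 1.5]]), square)
```

In this toy example the archetype at (1.5, 1.5) points toward the corner (2, 2), so it is moved to (1.6, 1.6), one fifth of the way to that corner.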
- project_archetypes_convex_hull(archetypes, X) → Array
Alternative archetype projection that uses convex combinations of extreme points.
This method identifies potential extreme points and creates archetypes as sparse convex combinations of these points, ensuring they lie on the boundary.
Technical details:
  - Multi-directional Exploration: Generates multiple random directions around the main archetype direction, allowing for more diverse extreme point discovery.
  - Sparse Convex Combinations: Creates archetypes as weighted combinations of extreme points found in different directions, with emphasis on the main direction.
  - Boundary Positioning: By using convex combinations of extreme points, archetypes are positioned on or near the convex hull boundary rather than in its interior.
This approach offers potentially better exploration of the convex hull boundary at the cost of slightly higher computational complexity.
- Parameters:
archetypes – Current archetype matrix
X – Original data matrix
- Returns:
Projected archetype matrix positioned on the convex hull boundary
- project_archetypes_knn(archetypes, X) → Array
Original k-NN based archetype projection (kept for comparison).
This method tends to pull archetypes inside the convex hull due to its averaging nature, which is suboptimal for archetypal analysis where archetypes should ideally lie on the convex hull boundary.
- Parameters:
archetypes – Current archetype matrix
X – Original data matrix
- Returns:
Projected archetype matrix (typically positioned inside the convex hull)
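The inward pull of an averaging projection is easy to see in a small NumPy sketch. The version below uses a plain, unweighted mean of the k nearest data points; the function name, the value of k, and the uniform weighting are assumptions, and the library's k-NN projection may weight neighbors differently.

```python
import numpy as np


def project_knn(archetypes, X, k=3):
    """Replace each archetype with the mean of its k nearest data points."""
    archetypes = np.asarray(archetypes, dtype=float)
    X = np.asarray(X, dtype=float)
    # Squared distances between every archetype and every data point.
    sq_dist = ((archetypes[:, None, :] - X[None, :, :]) ** 2).sum(axis=2)
    nearest = np.argsort(sq_dist, axis=1)[:, :k]  # (n_archetypes, k)
    return X[nearest].mean(axis=1)


points = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [5.0, 5.0]])
projected = project_knn(np.array([[0.2, 0.2]]), points, k=3)
```

The result, (1/3, 1/3), is the mean of three hull vertices and therefore lies strictly inside the hull, which is the behavior the description warns about.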
- qhull_init(X_jax, n_samples, n_features)
Initialize archetypes using convex hull vertices via QHull algorithm.
- set_fit_request(*, normalize: bool | None | str = '$UNCHANGED$') → ImprovedArchetypalAnalysis
Request metadata passed to the fit method.
Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see the User Guide on how the routing mechanism works.
The options for each parameter are:
  - True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
  - False: metadata is not requested and the meta-estimator will not pass it to fit.
  - None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
  - str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.
- transform(X: ndarray, y: ndarray | None = None, **kwargs) → ndarray
Transform new data to archetype weights using optimized methods.
- Parameters:
X – Data matrix of shape (n_samples, n_features)
y – Ignored. Present for API consistency by convention.
**kwargs – Additional keyword arguments for the transform method.
- Returns:
Weight matrix representing each sample as a combination of archetypes
- update_archetypes(archetypes, weights, X) → Array
Alternative archetype update strategy based on weighted reconstruction.
This approach directly optimizes archetypes by computing the pseudo-inverse of weights, which often provides a more targeted and mathematically sound update than gradient descent for this specific subproblem.
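The pseudo-inverse idea can be shown with a short NumPy sketch: for fixed weights W, the least-squares solution of X ≈ W·A is A = pinv(W)·X. The helper below is hypothetical and omits the current-archetypes argument and any projection or blending that the library method may apply afterwards.

```python
import numpy as np


def update_archetypes_pinv(weights, X):
    """Least-squares archetype update for fixed weights: A = pinv(W) @ X."""
    return np.linalg.pinv(weights) @ X


rng = np.random.default_rng(0)
true_archetypes = rng.normal(size=(3, 4))        # (n_archetypes, n_features)
weights = rng.dirichlet(np.ones(3), size=20)     # rows lie on the simplex
X = weights @ true_archetypes                    # noiseless data
recovered = update_archetypes_pinv(weights, X)
```

With noiseless data and a weight matrix of full column rank, the update recovers the generating archetypes in a single step, which is why it can be more direct than a gradient step for this subproblem.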
archetypax.models.base module
GPU-accelerated Archetypal Analysis implementation using JAX.
- class archetypax.models.base.ArchetypalAnalysis(n_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, normalize: bool = False, **kwargs)
Bases: BaseEstimator, TransformerMixin
GPU-accelerated Archetypal Analysis implementation using JAX.
- __init__(n_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, normalize: bool = False, **kwargs)
Initialize the Archetypal Analysis model.
- Parameters:
n_archetypes – Number of archetypes to find
max_iter – Maximum number of iterations
tol – Convergence tolerance
random_seed – Random seed for initialization
learning_rate – Learning rate for optimizer (reduced for better stability)
normalize – Whether to normalize the data before fitting.
**kwargs – Additional keyword arguments for the fit method.
- fit(X: ndarray, normalize: bool = False, **kwargs) → ArchetypalAnalysis
Fit the model to the data.
- Parameters:
X – Data matrix of shape (n_samples, n_features)
normalize – Whether to normalize the data before fitting.
**kwargs – Additional keyword arguments for the fit method.
- Returns:
Self
- fit_transform(X: ndarray, y: ndarray | None = None, normalize: bool = False) → ndarray
Fit the model and transform the data.
- Parameters:
X – Data matrix of shape (n_samples, n_features)
y – Ignored. Present for API consistency by convention.
normalize – Whether to normalize the data before fitting.
- Returns:
Weight matrix representing each sample as a combination of archetypes
- get_loss_history() → list[float]
Get the loss history from training.
- Returns:
List of loss values recorded during fitting
- project_archetypes(archetypes, X)
Project archetypes using soft assignment based on k-nearest neighbors.
- Parameters:
archetypes – Archetype matrix
X – Original data matrix
- Returns:
Projected archetype matrix
- project_weights(weights)
Project weights to satisfy simplex constraints with numerical stability.
- Parameters:
weights – Weight matrix
- Returns:
Projected weight matrix
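One simple way to enforce simplex constraints (non-negative rows that sum to one) with some numerical safety is to clip and renormalize, as in the sketch below. This is an assumption about the general approach; the library may use a different projection, such as an exact Euclidean projection onto the simplex.

```python
import numpy as np


def project_to_simplex(weights, eps=1e-10):
    """Clip weights to be positive and rescale each row to sum to one.

    `eps` keeps every entry strictly positive so that an all-zero row
    does not cause a division by zero.
    """
    w = np.maximum(np.asarray(weights, dtype=float), eps)
    return w / w.sum(axis=1, keepdims=True)


projected_weights = project_to_simplex([[0.2, -0.5, 0.8], [0.0, 0.0, 0.0]])
```

The negative entry in the first row is clipped to (almost) zero, and the all-zero second row becomes the uniform weighting (1/3, 1/3, 1/3).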
- reconstruct(X: ndarray = None) → ndarray
Reconstruct data using the learned archetypes.
- Parameters:
X – Data to reconstruct, or None to use the training data
- Returns:
Reconstructed data
- set_fit_request(*, normalize: bool | None | str = '$UNCHANGED$') → ArchetypalAnalysis
Request metadata passed to the fit method.
Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see the User Guide on how the routing mechanism works.
The options for each parameter are:
  - True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
  - False: metadata is not requested and the meta-estimator will not pass it to fit.
  - None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
  - str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.
- transform(X: ndarray, y: ndarray | None = None, **kwargs) → ndarray
Transform new data to archetype weights.
- Parameters:
X – New data matrix of shape (n_samples, n_features)
y – Ignored. Present for API consistency by convention.
**kwargs – Additional keyword arguments for the transform method.
- Returns:
Weight matrix representing each sample as a combination of archetypes
archetypax.models.biarchetypes module
Biarchetypal Analysis using JAX.
- class archetypax.models.biarchetypes.BiarchetypalAnalysis(n_row_archetypes: int, n_col_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, projection_method: str = 'default', lambda_reg: float = 0.01, **kwargs)
Bases: ImprovedArchetypalAnalysis
Biarchetypal Analysis using JAX.
This implementation follows the paper “Biarchetype analysis: simultaneous learning of observations and features based on extremes” by Alcacer et al.
Biarchetypal Analysis extends archetype analysis to simultaneously identify archetypes of both observations (rows) and features (columns). It represents the data matrix X as:
X ≃ alpha·beta·X·theta·gamma
where:
  - alpha, beta: Coefficients and archetypes for observations (rows)
  - theta, gamma: Coefficients and archetypes for features (columns)
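The factorization can be made concrete with a NumPy sketch that only checks shapes and computes a squared reconstruction error. The matrix shapes follow the parameter descriptions of the projection methods on this page (alpha: (n_samples, n_row_archetypes), beta: (n_row_archetypes, n_samples), theta: (n_features, n_col_archetypes), gamma: (n_col_archetypes, n_features)). The helper names and the choice of normalization axes for the random factors are assumptions made for illustration.

```python
import numpy as np


def biarchetypal_reconstruction(X, alpha, beta, theta, gamma):
    """Return the biarchetypes Z = beta @ X @ theta and X_hat = alpha @ Z @ gamma."""
    biarchetypes = beta @ X @ theta        # (n_row_archetypes, n_col_archetypes)
    X_hat = alpha @ biarchetypes @ gamma   # (n_samples, n_features)
    return biarchetypes, X_hat


rng = np.random.default_rng(0)
n_samples, n_features, n_row, n_col = 8, 5, 3, 2
X = rng.normal(size=(n_samples, n_features))


def row_stochastic(shape):
    m = rng.random(shape)
    return m / m.sum(axis=1, keepdims=True)  # each row sums to one


def col_stochastic(shape):
    m = rng.random(shape)
    return m / m.sum(axis=0, keepdims=True)  # each column sums to one


alpha = row_stochastic((n_samples, n_row))
beta = row_stochastic((n_row, n_samples))
theta = col_stochastic((n_features, n_col))
gamma = col_stochastic((n_col, n_features))

Z, X_hat = biarchetypal_reconstruction(X, alpha, beta, theta, gamma)
loss = float(np.sum((X - X_hat) ** 2))  # squared Frobenius reconstruction error
```

The intermediate matrix Z has shape (n_row_archetypes, n_col_archetypes), matching the shape documented for get_biarchetypes().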
- __init__(n_row_archetypes: int, n_col_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, projection_method: str = 'default', lambda_reg: float = 0.01, **kwargs)
Initialize the Biarchetypal Analysis model.
- Parameters:
n_row_archetypes – Number of archetypes for rows (observations)
n_col_archetypes – Number of archetypes for columns (features)
max_iter – Maximum number of iterations
tol – Convergence tolerance
random_seed – Random seed for initialization
learning_rate – Learning rate for optimizer
projection_method – Method for projecting archetypes
lambda_reg – Regularization parameter
**kwargs – Additional keyword arguments
- fit(X: ndarray, normalize: bool = False, **kwargs) → BiarchetypalAnalysis
Fit the Biarchetypal Analysis model to the data.
- Parameters:
X – Data matrix of shape (n_samples, n_features)
normalize – Whether to normalize the data before fitting
**kwargs – Additional keyword arguments for the fit method.
- Returns:
Self
- fit_transform(X: ndarray, y: ndarray | None = None, normalize: bool = False, **kwargs) → Any
Fit the model and transform the data.
- Parameters:
X – Data matrix
y – Target values (ignored)
normalize – Whether to normalize the data
**kwargs – Additional keyword arguments for the fit_transform method.
- Returns:
Tuple of (row_weights, col_weights)
- get_all_archetypes() → tuple[ndarray, ndarray]
Get both row and column archetypes.
- Returns:
Tuple of (row_archetypes, column_archetypes)
- get_all_weights() → tuple[ndarray, ndarray]
Get both row and column weights.
- Returns:
Tuple of (row_weights, column_weights)
- get_biarchetypes() → ndarray
Get the biarchetypes matrix.
- Returns:
Biarchetypes matrix of shape (n_row_archetypes, n_col_archetypes)
- loss_function(params: dict[str, Array], X: Array) → Array
Calculate the reconstruction loss for biarchetypal analysis.
- Parameters:
params – Dictionary containing alpha, beta, theta, gamma
X – Data matrix
- Returns:
Reconstruction loss
- project_col_archetypes(archetypes: Array, X: Array) → Array
Project column archetypes to be convex combinations of features.
This implementation employs a sophisticated feature-space boundary-seeking algorithm that:
  1. Identifies multiple extreme features in the direction of each archetype
  2. Uses adaptive weighting based on feature importance
  3. Ensures proper simplex constraints while maximizing diversity
- Parameters:
archetypes – Column archetype matrix (n_features, n_col_archetypes)
X – Data matrix (n_samples, n_features)
- Returns:
Projected column archetype matrix with enhanced diversity
- project_col_coefficients(coefficients: Array) → Array
Project column coefficients to satisfy simplex constraints.
- Parameters:
coefficients – Coefficient matrix (n_col_archetypes, n_features)
- Returns:
Projected coefficient matrix
- project_row_archetypes(archetypes: Array, X: Array) → Array
Project row archetypes to be convex combinations of data points.
This implementation employs an advanced boundary-seeking algorithm that:
  1. Identifies multiple extreme points in the direction of each archetype
  2. Uses adaptive weighting to balance diversity and stability
  3. Ensures proper simplex constraints are maintained
- Parameters:
archetypes – Row archetype matrix (n_row_archetypes, n_samples)
X – Data matrix (n_samples, n_features)
- Returns:
Projected row archetype matrix with enhanced diversity
- project_row_coefficients(coefficients: Array) → Array
Project row coefficients to satisfy simplex constraints.
- Parameters:
coefficients – Coefficient matrix (n_samples, n_row_archetypes)
- Returns:
Projected coefficient matrix
- reconstruct(X: ndarray = None) → ndarray
Reconstruct data from biarchetypes.
- Parameters:
X – Optional data matrix to reconstruct. If None, uses the training data.
- Returns:
Reconstructed data matrix
- set_fit_request(*, normalize: bool | None | str = '$UNCHANGED$') → BiarchetypalAnalysis
Request metadata passed to the fit method.
Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see the User Guide on how the routing mechanism works.
The options for each parameter are:
  - True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
  - False: metadata is not requested and the meta-estimator will not pass it to fit.
  - None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
  - str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.
- transform(X: ndarray, y: ndarray | None = None, **kwargs) → Any
Transform new data to row and column archetype weights.
- Parameters:
X – New data matrix of shape (n_samples, n_features)
y – Ignored. Present for API consistency by convention.
**kwargs – Additional keyword arguments for the transform method.
- Returns:
Tuple of (row_weights, col_weights) representing the data in terms of archetypes
archetypax.models.sparse_archetypes module
Sparse Archetypal Analysis model utilizing JAX.
- class archetypax.models.sparse_archetypes.SparseArchetypalAnalysis(n_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, lambda_reg: float = 0.01, lambda_sparsity: float = 0.1, sparsity_method: str = 'l1', normalize: bool = False, projection_method: str = 'cbap', projection_alpha: float = 0.1, archetype_init_method: str = 'directional', min_volume_factor: float = 0.001, **kwargs)
Bases: ImprovedArchetypalAnalysis
Archetypal Analysis incorporating sparsity constraints on archetypes.
This implementation enhances the ImprovedArchetypalAnalysis by introducing sparsity constraints to the archetypes, thereby improving interpretability, particularly in high-dimensional datasets.
- __init__(n_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, lambda_reg: float = 0.01, lambda_sparsity: float = 0.1, sparsity_method: str = 'l1', normalize: bool = False, projection_method: str = 'cbap', projection_alpha: float = 0.1, archetype_init_method: str = 'directional', min_volume_factor: float = 0.001, **kwargs)
Initialize the Sparse Archetypal Analysis model.
- Parameters:
n_archetypes – Number of archetypes to extract.
max_iter – Maximum number of iterations for optimization.
tol – Convergence tolerance.
random_seed – Random seed for reproducibility.
learning_rate – Learning rate for the optimizer.
lambda_reg – Regularization strength for weights.
lambda_sparsity – Regularization strength for archetype sparsity.
sparsity_method – Method for enforcing sparsity (“l1”, “l0_approx”, or “feature_selection”).
normalize – Whether to normalize data prior to fitting.
projection_method – Method for projecting archetypes (“cbap”, “convex_hull”, or “knn”).
projection_alpha – Strength of projection (0-1).
archetype_init_method – Method for initializing archetypes (“directional”, “qhull”, “kmeans_pp”).
min_volume_factor – Minimum volume factor to prevent degeneracy (0-1).
**kwargs – Additional keyword arguments.
- diversify_archetypes(archetypes, X)
Non-differentiable post-processing step to ensure archetype diversity.
This method guarantees that the archetypes form a non-degenerate simplex with sufficient volume. It is invoked outside the JAX-compiled update steps, as it employs non-differentiable operations.
- Parameters:
archetypes – Current archetypes array.
X – Data matrix.
- Returns:
Diversified archetypes array.
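A degenerate simplex can be detected by computing its volume. The sketch below uses the standard Gram-determinant formula for the volume of the simplex spanned by a set of points; how diversify_archetypes measures volume internally, and how it relates the result to min_volume_factor, is not documented here, so treat the helper as a hypothetical illustration.

```python
from math import factorial

import numpy as np


def simplex_volume(archetypes):
    """Volume of the simplex whose vertices are the rows of `archetypes`.

    Uses sqrt(det(E @ E.T)) / k!, where the k rows of E are the edge
    vectors from the first vertex to the remaining ones.
    """
    A = np.asarray(archetypes, dtype=float)
    edges = A[1:] - A[0]
    gram = edges @ edges.T
    k = edges.shape[0]
    # Clamp tiny negative determinants caused by floating-point round-off.
    return float(np.sqrt(max(np.linalg.det(gram), 0.0)) / factorial(k))


triangle = simplex_volume([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
collinear = simplex_volume([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
```

The right triangle has area 0.5, while three collinear archetypes give a volume of (numerically) zero, the degenerate case that a diversification step aims to avoid.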
- fit(X: ndarray, normalize: bool = False, **kwargs) → SparseArchetypalAnalysis
Fit the Sparse Archetypal Analysis model to the data.
- Parameters:
X – Input data matrix of shape (n_samples, n_features).
normalize – Whether to normalize data prior to fitting.
**kwargs – Additional keyword arguments.
- Returns:
Fitted model instance.
- fit_transform(X: ndarray, y: ndarray | None = None, normalize: bool = False, **kwargs) → ndarray
Fit the Sparse Archetypal Analysis model to the data and return the transformed data.
- Parameters:
X – Input data matrix of shape (n_samples, n_features).
y – Target values (not used).
normalize – Whether to normalize data prior to fitting.
**kwargs – Additional keyword arguments.
- Returns:
Transformed data matrix of shape (n_samples, n_archetypes).
- get_archetype_sparsity() → ndarray
Calculate the sparsity of each archetype.
- Returns:
Array containing the sparsity score for each archetype. Higher values indicate more sparse archetypes.
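The exact score returned by get_archetype_sparsity() is not specified on this page. One common choice that matches the "higher means sparser" convention is the fraction of near-zero entries per archetype, sketched below; the function name and the threshold are assumptions.

```python
import numpy as np


def sparsity_score(archetypes, threshold=1e-5):
    """Fraction of entries in each archetype whose magnitude is below `threshold`."""
    magnitudes = np.abs(np.asarray(archetypes, dtype=float))
    return (magnitudes < threshold).mean(axis=1)


scores = sparsity_score([[0.0, 0.0, 1.0, 0.0], [1.0, 1.0, 1.0, 1.0]])
```

Under this definition the first archetype scores 0.75 (three of four features are zero) and the second scores 0.0.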
- loss_function(archetypes, weights, X)
JIT-compiled loss function incorporating a sparsity constraint on archetypes.
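For the "l1" sparsity method, a loss of this kind typically adds an L1 penalty on the archetypes to the reconstruction error. The NumPy sketch below shows that general shape only; the exact terms, their scaling, and any further regularizers in the library's JIT-compiled loss are not documented here, so the function and its defaults are assumptions.

```python
import numpy as np


def sparse_loss(archetypes, weights, X, lambda_sparsity=0.1):
    """Mean squared reconstruction error plus an L1 penalty on the archetypes."""
    archetypes = np.asarray(archetypes, dtype=float)
    weights = np.asarray(weights, dtype=float)
    X = np.asarray(X, dtype=float)
    reconstruction = weights @ archetypes
    mse = np.mean((X - reconstruction) ** 2)
    l1_penalty = np.mean(np.abs(archetypes))
    return float(mse + lambda_sparsity * l1_penalty)


A_demo = np.array([[1.0, 0.0], [0.0, 1.0]])
W_demo = np.array([[1.0, 0.0], [0.5, 0.5]])
X_demo_exact = W_demo @ A_demo
demo_loss = sparse_loss(A_demo, W_demo, X_demo_exact)
```

Because the toy data are reconstructed exactly, the loss reduces to the penalty term alone: 0.1 times the mean absolute archetype entry (0.5), which is 0.05.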
- set_fit_request(*, normalize: bool | None | str = '$UNCHANGED$') → SparseArchetypalAnalysis
Request metadata passed to the fit method.
Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see the User Guide on how the routing mechanism works.
The options for each parameter are:
  - True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
  - False: metadata is not requested and the meta-estimator will not pass it to fit.
  - None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
  - str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.
Module contents
Core model implementations for Archetypal Analysis.