archetypax.models package
Submodules
archetypax.models.archetypes module
Improved Archetypal Analysis model using JAX.
- class archetypax.models.archetypes.ArchetypeTracker(*args, **kwargs)
Bases: ImprovedArchetypalAnalysis
A specialized subclass designed to monitor the movement of archetypes.
- __init__(*args, **kwargs)
Initialize the ArchetypeTracker with the same parameters as ImprovedArchetypalAnalysis.
- fit(X: ndarray, normalize: bool = False, **kwargs) → ArchetypeTracker
Train the model while recording the positions of the archetypes at each iteration.
- Parameters:
X – Data matrix of shape (n_samples, n_features)
normalize – Whether to normalize the data before fitting.
**kwargs – Additional keyword arguments for the fit method.
- Returns:
Self
- project_archetypes(archetypes: Array, X: Array) → Array
Override the parent-class projection with the adaptive-strength version used for tracking.
- Parameters:
archetypes – Archetype matrix of shape (n_archetypes, n_features)
X – Data matrix of shape (n_samples, n_features)
- Returns:
Projected archetype matrix of shape (n_archetypes, n_features)
- project_archetypes_with_adaptive_strength(archetypes: Array, X: Array) → Array
Modified projection function whose strength adapts to the current iteration.
In early iterations the projection is very gentle, which prevents large jumps in archetype position. As training progresses, the strength gradually increases to the normal projection strength.
- Parameters:
archetypes – Archetype matrix of shape (n_archetypes, n_features)
X – Data matrix of shape (n_samples, n_features)
- Returns:
Projected archetype matrix of shape (n_archetypes, n_features)
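The adaptive strength described above can be pictured as a warm-up schedule on the blend factor that moves archetypes toward the boundary. The sketch below assumes a linear warm-up from a small fraction of projection_alpha to its full value; adaptive_alpha, blend_toward_boundary, warmup_iters, and min_fraction are hypothetical names, not part of archetypax.

```python
import numpy as np


def adaptive_alpha(iteration: int, base_alpha: float = 0.1,
                   warmup_iters: int = 100, min_fraction: float = 0.1) -> float:
    """Hypothetical linear warm-up: start at min_fraction * base_alpha, reach base_alpha after warmup_iters."""
    progress = min(1.0, iteration / warmup_iters)
    return base_alpha * (min_fraction + (1.0 - min_fraction) * progress)


def blend_toward_boundary(archetypes: np.ndarray, boundary_points: np.ndarray,
                          iteration: int, base_alpha: float = 0.1) -> np.ndarray:
    """Move archetypes part of the way toward boundary targets; the step grows with the iteration count."""
    alpha = adaptive_alpha(iteration, base_alpha)
    return (1.0 - alpha) * archetypes + alpha * boundary_points


# Early iterations move archetypes much less than later ones.
A = np.zeros((2, 2))
B = np.ones((2, 2))
early = blend_toward_boundary(A, B, iteration=0)
late = blend_toward_boundary(A, B, iteration=500)
```

A linear ramp is only one possible choice; any monotone schedule that starts small would match the behavior described.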
- set_fit_request(*, normalize: bool | None | str = '$UNCHANGED$') → ArchetypeTracker
Request metadata passed to the fit method.
Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see the User Guide on how the routing mechanism works.
The options for each parameter are:
- True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
- False: metadata is not requested and the meta-estimator will not pass it to fit.
- None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
- str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.
- visualize_boundary_proximity(figsize=(10, 5)) → Any | None
Visualize how close archetypes stayed to the convex hull boundary.
- Parameters:
figsize – Figure size.
- Returns:
matplotlib figure
- visualize_movement(feature_indices: list[int] | None = None, figsize=(12, 8), interval: int = 1) → Any | None
Visualize how archetypes moved during optimization.
- Parameters:
feature_indices – Indices of the features to use for the 2D plot. If None, PCA is used.
figsize – Figure size.
interval – Plot every nth iteration to avoid overcrowding.
- Returns:
matplotlib figure
- class archetypax.models.archetypes.ImprovedArchetypalAnalysis(n_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, lambda_reg: float = 0.01, normalize: bool = False, projection_method: str = 'cbap', projection_alpha: float = 0.1, archetype_init_method: str = 'directional', **kwargs)
Bases: ArchetypalAnalysis
Improved Archetypal Analysis model using JAX.
- __init__(n_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, lambda_reg: float = 0.01, normalize: bool = False, projection_method: str = 'cbap', projection_alpha: float = 0.1, archetype_init_method: str = 'directional', **kwargs)
Initialize an enhanced archetypal analysis model with robust optimization.
This implementation addresses key limitations of standard archetypal analysis through alternative initialization strategies, hybrid gradient-based and algebraic optimization, and adaptive boundary projection. These enhancements aim to improve convergence stability, solution quality, and computational efficiency across diverse datasets.
- Parameters:
n_archetypes – Number of archetypes to discover - controls the model’s expressiveness and dimensionality reduction ratio. Higher values capture more nuanced patterns at the cost of interpretability and potential overfitting.
max_iter – Maximum optimization iterations - higher values allow more complete convergence at the cost of computational time. The default (500) balances solution quality with reasonable runtime for most datasets.
tol – Convergence tolerance - smaller values yield more precise solutions but require more iterations. The default (1e-6) is suitable for most applications, while scientific applications may require smaller values.
random_seed – Random seed for reproducibility - ensures consistent results across runs with the same data and parameters.
learning_rate – Gradient descent step size - critical parameter balancing convergence speed with stability. Too high risks overshooting minima, while too low causes slow convergence.
lambda_reg – Regularization strength for weights - controls the balance between reconstruction accuracy and weight sparsity. Higher values promote more discrete archetype assignments.
normalize – Whether to normalize features - essential when features have different scales to prevent dominance by high-magnitude features. Should be True for most real-world datasets.
projection_method –
Strategy for projecting archetypes to the boundary:
- “cbap” (default): convex boundary approximation projection; a balanced approach suitable for most datasets
- “convex_hull”: uses exact convex hull vertices; more precise but computationally intensive in high dimensions
- “knn”: k-nearest-neighbors approximation; faster for large datasets
projection_alpha – Projection strength parameter (0-1) - controls how aggressively archetypes are pushed toward the boundary. Higher values emphasize extremeness over reconstruction.
archetype_init_method –
Initialization strategy for archetypes:
- “directional” (default): directions from the centroid; a robust general-purpose approach that balances diversity with boundary alignment
- “qhull”/“convex_hull”: exact convex hull vertices; ideal when geometric extremes are well-defined
- “kmeans”/“kmeans++”: k-means++ initialization; beneficial when density-based initialization aligns with domain expectations
**kwargs –
Additional parameters:
- early_stopping_patience: iterations without improvement before stopping
- verbose_level: controls logging detail (0-4)
  0: Critical only
  1: Error level
  2: Warning level
  3: Info level (recommended for monitoring)
  4: Debug level (verbose training details)
- logger_level: alternative to verbose_level with reversed mapping
- directional_init(X_jax: Array, n_samples: int, n_features: int) → tuple[Array, Array]
Initialize archetypes from directions that are evenly distributed on a sphere around the data centroid.
- Parameters:
X_jax – Data matrix of shape (n_samples, n_features)
n_samples – Number of samples
n_features – Number of features
- Returns:
Archetypes and archetype indices
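A minimal NumPy sketch of direction-based initialization follows. It assumes that each direction selects the data point with the largest projection onto it, and it draws directions at random (uniform on the sphere in distribution) rather than spacing them deterministically; directional_init_sketch is a hypothetical name, not the library method.

```python
import numpy as np


def directional_init_sketch(X: np.ndarray, n_archetypes: int, seed: int = 42):
    """For each random unit direction, pick the data point farthest along it from the centroid."""
    rng = np.random.default_rng(seed)
    centroid = X.mean(axis=0)
    # Normalized Gaussian vectors are uniformly distributed on the unit sphere.
    directions = rng.normal(size=(n_archetypes, X.shape[1]))
    directions /= np.linalg.norm(directions, axis=1, keepdims=True)
    # Projection of every centered sample onto every direction: (n_samples, n_archetypes).
    scores = (X - centroid) @ directions.T
    indices = np.argmax(scores, axis=0)
    return X[indices], indices


X = np.random.default_rng(0).normal(size=(200, 3))
archetypes, indices = directional_init_sketch(X, n_archetypes=4)
```

Two directions can occasionally select the same sample; a production version would need to deduplicate or redraw.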
- fit(X: ndarray, normalize: bool = False, **kwargs) → ImprovedArchetypalAnalysis
Fit the model by combining several initialization and optimization strategies.
This core method identifies the extreme patterns that define the convex hull of the data and serve as the building blocks for representing all observations. The implementation includes several enhancements:
Initialization strategies that start archetypes near the extremes of the data
Hybrid optimization combining gradient-based and direct algebraic updates
Adaptive boundary projection to keep archetypes at the extremes of the data
Regularization for improved numerical stability
Early stopping to avoid wasted computation once the loss stops improving
Together, these address the main difficulties of archetypal analysis: sensitivity to initialization, convergence to suboptimal solutions, and computational cost.
- Parameters:
X – Data matrix to analyze (n_samples, n_features)
normalize – Whether to normalize features before fitting - essential for data with features of different scales or magnitudes
**kwargs –
Additional optimization parameters:
- early_stopping_patience: iterations without improvement before stopping (higher values give the optimizer more time to converge, at extra computational cost)
- additional parameters specific to the initialization method
- Returns:
Self - fitted model instance with discovered archetypes
- fit_transform(X: ndarray, y: ndarray | None = None, normalize: bool = False, **kwargs) → ndarray
Fit the model and immediately transform the input data.
This convenience method combines model fitting and data transformation in a single operation, which offers two key advantages:
Computational efficiency by avoiding redundant calculations
Simplified workflow for immediate archetypal representation
This method is particularly valuable in analysis pipelines or when integrating with scikit-learn compatible frameworks that expect this pattern. It ensures that the transformation is performed with the same preprocessing settings used during fitting.
- Parameters:
X – Data matrix to fit and transform (n_samples, n_features)
y – Ignored, present for scikit-learn API compatibility
normalize – Whether to normalize features before fitting - essential for data with different scales or magnitudes
**kwargs – Additional parameters passed to both fit() and transform(), including optimization settings and convergence criteria
- Returns:
Weight matrix representing each sample as a combination of the discovered archetypes (n_samples, n_archetypes)
- kmeans_pp_init(X_jax: Array, n_samples: int, n_features: int) → tuple[Array, Array]
Efficient k-means++-style initialization implemented in JAX.
- Parameters:
X_jax – Data matrix of shape (n_samples, n_features)
n_samples – Number of samples
n_features – Number of features
- Returns:
Archetypes and archetype indices
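The k-means++ seeding rule can be sketched in NumPy as follows: the first seed is drawn uniformly, and each later seed is drawn with probability proportional to its squared distance from the nearest seed already chosen. This is the standard k-means++ rule rather than a transcription of the JAX code; kmeans_pp_init_sketch is a hypothetical name.

```python
import numpy as np


def kmeans_pp_init_sketch(X: np.ndarray, n_archetypes: int, seed: int = 42):
    """k-means++ seeding: sample each new seed proportionally to squared distance from existing seeds."""
    rng = np.random.default_rng(seed)
    n_samples = X.shape[0]
    indices = [int(rng.integers(n_samples))]
    # Squared distance from every point to its closest selected seed.
    d2 = np.sum((X - X[indices[0]]) ** 2, axis=1)
    for _ in range(1, n_archetypes):
        idx = int(rng.choice(n_samples, p=d2 / d2.sum()))
        indices.append(idx)
        d2 = np.minimum(d2, np.sum((X - X[idx]) ** 2, axis=1))
    indices = np.asarray(indices)
    return X[indices], indices


X = np.random.default_rng(1).normal(size=(300, 2))
archetypes, indices = kmeans_pp_init_sketch(X, n_archetypes=5)
```

Already-selected points have zero distance and therefore zero probability, so seeds are distinct as long as the data has no duplicate rows.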
- loss_function(archetypes: Array, weights: Array, X: Array) → Array
Composite objective function balancing reconstruction with interpretability.
The loss guides optimization by balancing several competing objectives of archetypal analysis:
Reconstruction fidelity: Ensuring archetypes accurately represent the data
Weight interpretability: Encouraging sparse, distinctive weight patterns
Boundary alignment: Promoting archetypes at meaningful extremal positions
The weighted combination of these terms steers optimization toward solutions that are both mathematically valid (a convex hull representation) and practically useful (interpretable patterns). The relative weights of the terms determine the trade-off between reconstruction accuracy and archetypal properties.
The function is JIT-compiled, since it is evaluated at every step of the optimization.
- Parameters:
archetypes – Candidate archetype matrix (n_archetypes, n_features)
weights – Weight matrix (n_samples, n_archetypes) describing how to represent each sample as a combination of archetypes
X – Original data matrix (n_samples, n_features) to reconstruct
- Returns:
Scalar loss value combining reconstruction error with regularization terms - lower values indicate better solutions
- project_archetypes(archetypes: Array, X: Array) → Array
Position archetypes on or near the convex hull boundary.
This step keeps archetypes at the extremes of the data distribution, where they best represent distinctive patterns. This implementation differs from standard projection methods by:
Projecting along meaningful directions from the data centroid
Identifying precise extreme points rather than using approximate methods
Blending original positions with boundary points for stability
Applying adaptive adjustments based on current position
- Parameters:
archetypes – Current archetype positions (n_archetypes, n_features)
X – Data matrix defining the convex hull (n_samples, n_features)
- Returns:
Projected archetypes positioned at or near the convex hull boundary
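The steps listed above can be sketched in NumPy under two assumptions: the "extreme point" along a direction is the sample with the largest projection onto it, and blending uses a single factor alpha in the role of projection_alpha. This illustrates the technique and is not the archetypax implementation; project_to_boundary_sketch is a hypothetical name.

```python
import numpy as np


def project_to_boundary_sketch(archetypes: np.ndarray, X: np.ndarray, alpha: float = 0.1) -> np.ndarray:
    """Blend each archetype toward the most extreme sample along its direction from the data centroid."""
    centroid = X.mean(axis=0)
    centered = X - centroid
    projected = archetypes.astype(float)  # copy, so the input is not modified
    for k, archetype in enumerate(archetypes):
        direction = archetype - centroid
        norm = np.linalg.norm(direction)
        if norm < 1e-10:
            continue  # archetype sits on the centroid: no meaningful direction, leave it unchanged
        direction = direction / norm
        extreme = X[np.argmax(centered @ direction)]
        # alpha = 0 keeps the archetype; alpha = 1 snaps it onto the extreme sample.
        projected[k] = (1.0 - alpha) * archetype + alpha * extreme
    return projected


rng = np.random.default_rng(0)
X = rng.normal(size=(100, 2))
A = 0.5 * X[:3]
snapped = project_to_boundary_sketch(A, X, alpha=1.0)
```

Blending rather than snapping (alpha < 1) trades some extremeness for stability, which matches the "blending original positions with boundary points" item above.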
- project_archetypes_convex_hull(archetypes: Array, X: Array) → Array
Alternative archetype projection that uses convex combinations of extreme points.
This method identifies candidate extreme points and builds each archetype as a sparse convex combination of them, placing it on or near the boundary.
Technical details:
- Multi-directional exploration: generates multiple random directions around the main archetype direction, allowing more diverse extreme points to be discovered.
- Sparse convex combinations: creates archetypes as weighted combinations of the extreme points found in different directions, with emphasis on the main direction.
- Boundary positioning: because archetypes are convex combinations of extreme points, they are positioned on or near the convex hull boundary rather than in its interior.
This approach can explore the convex hull boundary more thoroughly, at the cost of slightly higher computational complexity.
- Parameters:
archetypes – Current archetype matrix of shape (n_archetypes, n_features)
X – Original data matrix of shape (n_samples, n_features)
- Returns:
Projected archetype matrix positioned on or near the convex hull boundary
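A NumPy sketch of the multi-directional idea follows. The number of directions, the Gaussian perturbation scale, and the split of weight between the main and perturbed directions are all hypothetical choices made for illustration; project_convex_combination_sketch is not an archetypax function.

```python
import numpy as np


def project_convex_combination_sketch(archetypes, X, n_directions=5, noise_scale=0.1,
                                      main_weight=0.5, seed=0):
    """Replace each archetype by a convex combination of extreme samples along perturbed directions."""
    rng = np.random.default_rng(seed)
    centroid = X.mean(axis=0)
    centered = X - centroid
    projected = np.empty((archetypes.shape[0], X.shape[1]))
    # Convex weights (n_directions >= 2): main_weight on the main direction, the rest shared equally.
    weights = np.full(n_directions, (1.0 - main_weight) / (n_directions - 1))
    weights[0] = main_weight
    for k, archetype in enumerate(archetypes):
        main = archetype - centroid
        norm = np.linalg.norm(main)
        if norm < 1e-10:
            projected[k] = archetype  # no meaningful direction; leave unchanged
            continue
        main = main / norm
        perturbed = main + noise_scale * rng.normal(size=(n_directions - 1, main.size))
        directions = np.vstack([main, perturbed])
        directions /= np.linalg.norm(directions, axis=1, keepdims=True)
        # One extreme sample per direction: at most n_directions nonzero coefficients (sparse).
        extreme_idx = np.argmax(centered @ directions.T, axis=0)
        projected[k] = weights @ X[extreme_idx]
    return projected


rng = np.random.default_rng(0)
X = rng.normal(size=(150, 2))
P = project_convex_combination_sketch(0.5 * X[:4], X)
```

Because the coefficients are non-negative and sum to 1, each result stays inside the convex hull, near the boundary when the extreme samples are close together.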
- project_archetypes_knn(archetypes: Array, X: Array) → Array
Original k-NN based archetype projection (kept for comparison).
This method tends to pull archetypes inside the convex hull due to its averaging nature, which is suboptimal for archetypal analysis where archetypes should ideally lie on the convex hull boundary.
- Parameters:
archetypes – Current archetype matrix
X – Original data matrix
- Returns:
Projected archetype matrix (typically positioned inside the convex hull)
- project_weights(weights: Array) → Array
JIT-compiled projection of the weights onto the simplex, so that each row is non-negative and sums to 1.
- Parameters:
weights – Weight matrix of shape (n_samples, n_archetypes)
- Returns:
Projected weight matrix of shape (n_samples, n_archetypes)
- qhull_init(X_jax: Array, n_samples: int, n_features: int) → tuple[Array, Array]
Initialize archetypes using convex hull vertices computed with Qhull.
- set_fit_request(*, normalize: bool | None | str = '$UNCHANGED$') → ImprovedArchetypalAnalysis
Request metadata passed to the fit method.
Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see the User Guide on how the routing mechanism works.
The options for each parameter are:
- True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
- False: metadata is not requested and the meta-estimator will not pass it to fit.
- None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
- str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.
- transform(X: ndarray, y: ndarray | None = None, **kwargs) → ndarray
Transform data into archetypal weight space with adaptive optimization.
This method computes optimal weights representing each sample as a convex combination of discovered archetypes. The transformation reveals how samples relate to extreme patterns, offering:
Dimensionality reduction while preserving interpretability
Soft clustering based on meaningful archetypes rather than arbitrary centroids
Insights into sample composition and relationship to extreme patterns
A foundation for transfer learning when applying archetypes to new data
Multiple optimization strategies are available, with adaptive selection based on dataset size to balance computational efficiency with solution quality.
- Parameters:
X – Data matrix to transform (n_samples, n_features)
y – Ignored, present for scikit-learn API compatibility
**kwargs –
Additional parameters:
- method: optimization approach to use:
  - “lbfgs”: best for small datasets (<1000 samples)
  - “adam”: balanced option for mid-sized data (default)
  - “sgd”: memory-efficient for large datasets
  - “adaptive”: automatically selects based on data size
- max_iter: maximum iterations for weight optimization
- tol: convergence tolerance (smaller values for more precision)
- Returns:
Weight matrix representing each sample as a combination of the discovered archetypes (n_samples, n_archetypes)
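For fixed archetypes, computing the weights is a convex least-squares problem with each row constrained to the simplex. The sketch below uses plain projected gradient descent with step size 1/L and the standard sort-based Euclidean projection onto the simplex. It stands in for the “lbfgs”, “adam”, “sgd”, and “adaptive” options, which it does not reproduce; transform_sketch and project_rows_to_simplex are hypothetical names.

```python
import numpy as np


def project_rows_to_simplex(V: np.ndarray) -> np.ndarray:
    """Euclidean projection of each row onto the probability simplex (sort-based algorithm)."""
    n_rows, k = V.shape
    U = -np.sort(-V, axis=1)  # rows sorted in descending order
    css = np.cumsum(U, axis=1) - 1.0
    positive = U - css / np.arange(1, k + 1) > 0
    rho = k - 1 - np.argmax(positive[:, ::-1], axis=1)  # last index where the condition holds
    theta = css[np.arange(n_rows), rho] / (rho + 1)
    return np.maximum(V - theta[:, None], 0.0)


def transform_sketch(X: np.ndarray, archetypes: np.ndarray, n_iter: int = 500) -> np.ndarray:
    """Projected gradient descent on 0.5 * ||X - W A||^2 with each row of W on the simplex."""
    k = archetypes.shape[0]
    W = np.full((X.shape[0], k), 1.0 / k)
    # 1 / Lipschitz constant of the gradient with respect to W.
    step = 1.0 / np.linalg.norm(archetypes @ archetypes.T, 2)
    for _ in range(n_iter):
        grad = (W @ archetypes - X) @ archetypes.T
        W = project_rows_to_simplex(W - step * grad)
    return W


archetypes = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
X = np.array([[0.2, 0.3], [0.5, 0.1], [0.1, 0.8]])  # points inside the triangle
W = transform_sketch(X, archetypes)
```

For points inside the triangle spanned by the archetypes, the recovered weights are their exact barycentric coordinates.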
- update_archetypes(archetypes: Array, weights: Array, X: Array) → Array
Alternative archetype update strategy based on weighted reconstruction.
This approach updates the archetypes directly through the pseudo-inverse of the weights. For fixed weights W, A = pinv(W)·X minimizes the squared reconstruction error ||X - W·A||², so this update often targets the archetype subproblem more directly than a gradient step.
- Parameters:
archetypes – Archetype matrix of shape (n_archetypes, n_features)
weights – Weight matrix of shape (n_samples, n_archetypes)
X – Data matrix of shape (n_samples, n_features)
- Returns:
Updated archetype matrix of shape (n_archetypes, n_features)
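The closed-form update can be checked with a small NumPy sketch: when the data are generated exactly as W·A and W has full column rank, pinv(W)·X recovers A. The sketch omits any blending or projection step the library may apply afterwards; update_archetypes_sketch is a hypothetical name.

```python
import numpy as np


def update_archetypes_sketch(weights: np.ndarray, X: np.ndarray) -> np.ndarray:
    """Closed-form update: A = pinv(W) @ X minimizes ||X - W A||_F^2 for fixed W."""
    return np.linalg.pinv(weights) @ X


rng = np.random.default_rng(0)
A_true = rng.normal(size=(3, 4))
W = rng.dirichlet(np.ones(3), size=50)  # each row lies on the simplex
X = W @ A_true
A_hat = update_archetypes_sketch(W, X)
```

The pseudo-inverse solution is unconstrained, so nothing forces A_hat inside the convex hull of the data; this is one reason a separate projection step is useful.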
archetypax.models.base module
GPU-accelerated Archetypal Analysis implementation using JAX.
This module provides a foundational implementation of Archetypal Analysis (AA) optimized for GPU acceleration via JAX. AA identifies extreme points (archetypes) that can represent the entire dataset through convex combinations.
- class archetypax.models.base.ArchetypalAnalysis(n_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, normalize: bool = False, **kwargs)
Bases: BaseEstimator, TransformerMixin
GPU-accelerated Archetypal Analysis implementation using JAX.
This class provides the core functionality for identifying archetypes: extreme points that can represent data through convex combinations, offering interpretable insights into data structure.
Leverages JAX for efficient GPU computation and automatic differentiation.
- __init__(n_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, normalize: bool = False, **kwargs)
Initialize the Archetypal Analysis model.
- Parameters:
n_archetypes – Number of archetypes to find - determines the dimensionality of the representation space
max_iter – Maximum number of iterations for optimization convergence
tol – Convergence tolerance for early stopping
random_seed – Random seed for reproducible results
learning_rate – Learning rate for optimizer - lower values provide better stability at the cost of slower convergence
normalize – Whether to normalize the data before fitting - essential for features with different scales
**kwargs –
Additional keyword arguments including:
- early_stopping_patience: number of iterations without improvement before stopping optimization
- logger_level/verbose_level: control for logging granularity
- fit(X: ndarray, normalize: bool = False, **kwargs) → ArchetypalAnalysis
Fit the model to the data.
Identifies optimal archetypes and weights through iterative optimization. Uses Adam optimizer with projection steps to ensure constraints are satisfied.
- Parameters:
X – Data matrix (n_samples, n_features)
normalize – Whether to normalize the data before fitting
**kwargs – Additional keyword arguments for fine-tuning the fitting process
- Returns:
Self - fitted model instance
- fit_transform(X: ndarray, y: ndarray | None = None, normalize: bool = False) → ndarray
Fit the model and transform the data.
- Parameters:
X – Data matrix of shape (n_samples, n_features)
y – Ignored. Present for API consistency by convention.
normalize – Whether to normalize the data before fitting.
- Returns:
Weight matrix representing each sample as a combination of archetypes
- get_loss_history() → list[float]
Get the loss history from training.
- Returns:
List of loss values recorded during fitting
- loss_function(archetypes: Array, weights: Array, X: Array) → Array
Calculate reconstruction loss with entropy regularization.
Computes the fundamental objective: minimize reconstruction error while encouraging more discriminative weights through entropy regularization.
- Parameters:
archetypes – Archetype matrix (n_archetypes, n_features)
weights – Weight matrix (n_samples, n_archetypes)
X – Data matrix (n_samples, n_features)
- Returns:
Combined loss value as a scalar
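A NumPy sketch of a reconstruction-plus-entropy objective follows. The exact form is an assumption: the mean over samples and the lambda_reg coefficient are illustrative (the base signature does not expose lambda_reg), and aa_loss_sketch is a hypothetical name. Lower entropy means weights concentrated on fewer archetypes, which is the "more discriminative" direction the description refers to.

```python
import numpy as np


def aa_loss_sketch(archetypes, weights, X, lambda_reg=0.01, eps=1e-10):
    """Mean squared reconstruction error plus lambda_reg times the mean row entropy of the weights."""
    reconstruction = weights @ archetypes
    reconstruction_error = np.mean(np.sum((X - reconstruction) ** 2, axis=1))
    # eps guards log(0); zero weights contribute 0 * log(eps) = 0.
    entropy = -np.sum(weights * np.log(weights + eps), axis=1)
    return reconstruction_error + lambda_reg * np.mean(entropy)


archetypes = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
one_hot = np.eye(3)
uniform = np.full((3, 3), 1.0 / 3.0)
# Both reconstruct their targets exactly, so only the entropy term differs.
loss_one_hot = aa_loss_sketch(archetypes, one_hot, one_hot @ archetypes)
loss_uniform = aa_loss_sketch(archetypes, uniform, uniform @ archetypes)
```

With exact reconstruction, the one-hot weights incur essentially no penalty, while uniform weights pay lambda_reg·log(3).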
- project_archetypes(archetypes: Array, X: Array) → Array
Project archetypes using soft assignment based on k-nearest neighbors.
Ensures archetypes remain within the convex hull of data points by creating soft assignments based on proximity. This approach offers better stability than hard assignment methods.
- Parameters:
archetypes – Archetype matrix (n_archetypes, n_features)
X – Original data matrix (n_samples, n_features)
- Returns:
Projected archetype matrix (n_archetypes, n_features)
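The soft k-NN projection can be sketched in NumPy as a weighted average of each archetype's nearest samples. The softmax-over-distances weighting and k=10 are assumptions; project_archetypes_knn_sketch is a hypothetical name. The usage lines place archetypes far outside the data to show the averaging pulling them inside the hull.

```python
import numpy as np


def project_archetypes_knn_sketch(archetypes, X, k=10):
    """Replace each archetype by a softmax-weighted average of its k nearest samples."""
    projected = np.empty((archetypes.shape[0], X.shape[1]))
    for j, archetype in enumerate(archetypes):
        dists = np.linalg.norm(X - archetype, axis=1)
        nearest = np.argsort(dists)[:k]
        # Soft assignment: closer neighbors get larger weights; weights are positive and sum to 1.
        logits = -dists[nearest] / (dists[nearest].mean() + 1e-10)
        w = np.exp(logits - logits.max())
        w /= w.sum()
        projected[j] = w @ X[nearest]
    return projected


rng = np.random.default_rng(0)
X = rng.uniform(-1.0, 1.0, size=(200, 2))
far_archetypes = np.array([[5.0, 5.0], [-5.0, 0.0]])  # well outside the data
projected = project_archetypes_knn_sketch(far_archetypes, X)
```

Since the result is a convex combination of data points, it always lies within the convex hull, often strictly inside it.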
- project_weights(weights: Array) → Array
Project weights to satisfy simplex constraints with numerical stability.
Ensures that weights form valid convex combinations (non-negative and sum to 1) while avoiding numerical underflow/overflow issues.
- Parameters:
weights – Weight matrix (n_samples, n_archetypes)
- Returns:
Projected weight matrix (n_samples, n_archetypes)
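One simple, numerically safe way to meet these constraints is to clip to a small positive floor and renormalize each row, as sketched below. This is an assumption about the approach, and project_weights_sketch is a hypothetical name. Clip-and-renormalize is cheaper than an exact Euclidean projection onto the simplex, but it is not equivalent to one, because it rescales entries rather than shifting them.

```python
import numpy as np


def project_weights_sketch(weights: np.ndarray, eps: float = 1e-10) -> np.ndarray:
    """Clip to a small positive floor, then renormalize each row to sum to 1."""
    clipped = np.clip(weights, eps, None)  # non-negative and never exactly zero
    # Each row sum is at least n_archetypes * eps > 0, so the division is always safe.
    return clipped / clipped.sum(axis=1, keepdims=True)


W = np.array([[2.0, 2.0, 0.0], [-1.0, 1.0, 1.0]])
P = project_weights_sketch(W)
```

Keeping every entry strictly positive also avoids log(0) if an entropy term is computed on the weights.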
- reconstruct(X: ndarray | None = None) → ndarray
Reconstruct data using the learned archetypes.
- Parameters:
X – Data to reconstruct, or None to use the training data
- Returns:
Reconstructed data of shape (n_samples, n_features)
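In archetypal analysis, the reconstruction is the product of the weight matrix and the archetype matrix, X̂ = W·A. The toy numbers below are illustrative only. The sketch assumes that, when X is given, its weights would first be obtained with transform(X); that step is not shown.

```python
import numpy as np

# Three archetypes in 2-D (rows) and two samples expressed as convex combinations of them.
archetypes = np.array([[0.0, 0.0], [4.0, 0.0], [0.0, 4.0]])
weights = np.array([[0.5, 0.5, 0.0],
                    [0.25, 0.25, 0.5]])
reconstructed = weights @ archetypes  # (n_samples, n_features); rows are (2, 0) and (1, 2)
```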
- set_fit_request(*, normalize: bool | None | str = '$UNCHANGED$') → ArchetypalAnalysis
Request metadata passed to the fit method.
Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see the User Guide on how the routing mechanism works.
The options for each parameter are:
- True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
- False: metadata is not requested and the meta-estimator will not pass it to fit.
- None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
- str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.
- transform(X: ndarray, y: ndarray | None = None, **kwargs) → ndarray
Transform new data to archetype weights.
- Parameters:
X – New data matrix of shape (n_samples, n_features)
y – Ignored. Present for API consistency by convention.
**kwargs – Additional keyword arguments for the transform method.
- Returns:
Weight matrix representing each sample as a combination of archetypes
archetypax.models.biarchetypes module
Biarchetypal Analysis: Dual-perspective archetype discovery using JAX.
This module implements Biarchetypal Analysis (BA), which extends traditional Archetypal Analysis by simultaneously identifying archetypes in both observation space (rows) and feature space (columns). This dual-perspective approach enables more comprehensive data understanding, revealing patterns that would remain hidden in single-direction analysis.
- class archetypax.models.biarchetypes.BiarchetypalAnalysis(n_row_archetypes: int, n_col_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, projection_method: str = 'default', lambda_reg: float = 0.01, **kwargs)
Bases: ImprovedArchetypalAnalysis
Biarchetypal Analysis for dual-directional pattern discovery.
This implementation extends archetypal analysis to simultaneously identify extreme patterns in both observations (rows) and features (columns), offering a richer understanding of data structure. Traditional archetypal analysis only identifies patterns in observation space, missing crucial feature-level insights.
By factorizing the data matrix X as X ≃ α·β·X·θ·γ (alpha·beta·X·theta·gamma), BA provides several advantages:
- Captures both observation-level and feature-level patterns
- Enables cross-modal analysis between observations and features
- Creates a more compact and interpretable representation via biarchetypes
- Reveals latent relationships that single-directional methods cannot detect
Based on the work by Alcacer et al., “Biarchetype analysis: simultaneous learning of observations and features based on extremes.”
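To make the shapes in X ≃ α·β·X·θ·γ concrete, the sketch below builds random factors with the dimensions listed for loss_function and multiplies them out. The simplex constraints (rows of α and β, and columns of θ and γ, summing to 1) are assumed from the biarchetype formulation rather than stated on this page. The matrices are random placeholders, not fitted values, and simplex_rows is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features = 50, 8
n_row_archetypes, n_col_archetypes = 3, 2


def simplex_rows(n_rows: int, n_cols: int) -> np.ndarray:
    """Random non-negative matrix whose rows each sum to 1."""
    M = rng.random((n_rows, n_cols))
    return M / M.sum(axis=1, keepdims=True)


X = rng.normal(size=(n_samples, n_features))
alpha = simplex_rows(n_samples, n_row_archetypes)     # each observation mixes row archetypes
beta = simplex_rows(n_row_archetypes, n_samples)      # each row archetype mixes observations
theta = simplex_rows(n_col_archetypes, n_features).T  # (n_features, n_col_archetypes); columns mix features
gamma = simplex_rows(n_features, n_col_archetypes).T  # (n_col_archetypes, n_features); columns mix column archetypes

row_archetypes = beta @ X   # (n_row_archetypes, n_features)
Z = beta @ X @ theta        # biarchetypes: (n_row_archetypes, n_col_archetypes)
X_hat = alpha @ Z @ gamma   # reconstruction: (n_samples, n_features)
```

The small Z matrix sits at the center of the product, which is why the biarchetypes give a compact summary of the whole data matrix.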
- __init__(n_row_archetypes: int, n_col_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, projection_method: str = 'default', lambda_reg: float = 0.01, **kwargs)
Initialize the Biarchetypal Analysis model.
- Parameters:
n_row_archetypes – Number of row archetypes - controls expressiveness in observation space (rows)
n_col_archetypes – Number of column archetypes - controls expressiveness in feature space (columns)
max_iter – Maximum optimization iterations - higher values enable better convergence at computational cost
tol – Convergence tolerance for early stopping - smaller values yield more precise solutions but require more iterations
random_seed – Random seed for reproducibility across runs
learning_rate – Gradient descent step size - critical balance between convergence speed and stability
projection_method – Method for projecting archetypes to extreme points: “default” uses convex boundary approximation
lambda_reg – Regularization strength - controls sparsity/smoothness tradeoff in archetype weights
**kwargs – Additional parameters including:
- early_stopping_patience: iterations with no improvement before stopping
- verbose_level/logger_level: controls logging detail
- fit(X: ndarray, normalize: bool = False, **kwargs) → BiarchetypalAnalysis
Fit the Biarchetypal Analysis model to identify dual-perspective archetypes.
This core method performs the four-factor decomposition of the data matrix, simultaneously discovering patterns in observation and feature spaces. The implementation relies on several optimization strategies:
Dedicated initialization for both row and column factors
Adaptive learning rate scheduling for stable convergence
Specialized projection operations to maintain meaningful boundaries
Careful numerical handling to prevent instability
Early stopping with convergence monitoring
These optimizations are essential due to the complexity of the four-factor model, which is more challenging to optimize than standard Archetypal Analysis.
- Parameters:
X – Data matrix (n_samples, n_features)
normalize – Whether to normalize features - essential for data with different scales
**kwargs – Additional parameters for customizing the fitting process
- Returns:
Self - fitted model instance with discovered biarchetypes
- fit_transform(X: ndarray, y: ndarray | None = None, normalize: bool = False, **kwargs) → Any
Fit the model and transform the data in a single operation.
This convenience method combines model fitting and data transformation in a single step, offering two key advantages:
1. Computational efficiency by avoiding redundant calculations
2. Simplified workflow for immediate biarchetypal representation
The method is particularly valuable when the biarchetypal representation is needed immediately after fitting, such as in analysis pipelines or when integrating with scikit-learn compatible frameworks.
- Parameters:
X – Data matrix to fit and transform (n_samples, n_features)
y – Ignored. Present for scikit-learn API compatibility
normalize – Whether to normalize features before fitting
**kwargs – Additional parameters passed to fit()
- Returns:
Tuple of (row_weights, col_weights) representing the data in biarchetypal space
- get_all_archetypes() → tuple[ndarray, ndarray]
Retrieve both row and column archetypes in a single call.
This convenience method provides access to both directions of archetypal analysis simultaneously, facilitating comprehensive analysis and visualization of the dual-perspective patterns. Accessing both archetypes together is particularly valuable for cross-modal analysis examining relationships between observation patterns and feature patterns.
- Returns:
Tuple of (row_archetypes, column_archetypes) matrices
- get_all_weights() → tuple[ndarray, ndarray]
Retrieve both row and column weights in a single call.
This convenience method provides access to all weight coefficients simultaneously, enabling comprehensive analysis of how observations and features relate to their respective archetypes. This unified view is particularly valuable for understanding the full biarchetypal decomposition and how information flows between the row and column spaces.
- Returns:
Tuple of (row_weights, column_weights) matrices
- get_biarchetypes() → ndarray
Retrieve the core biarchetypes matrix.
The biarchetypes matrix (Z = β·X·θ) represents the heart of the model, capturing the essential patterns at the intersection of row and column archetypes. This matrix provides a compact representation of the data’s underlying structure, with each element representing a specific row-column archetype interaction.
Access to this matrix is essential for visualization, interpretation, and advanced analysis of the identified patterns.
- Returns:
Biarchetypes matrix (n_row_archetypes, n_col_archetypes)
- get_col_archetypes() → ndarray
Retrieve the column archetypes.
Column archetypes represent extreme patterns in feature space, describing distinctive feature combinations or “feature types.” This perspective is unique to biarchetypal analysis and provides crucial insights about feature relationships that would be missed in standard archetypal analysis.
These archetypes enable feature-level interpretations and can reveal coordinated feature behaviors across different data contexts.
- Returns:
Column archetypes matrix (n_col_archetypes, n_features)
- get_col_weights() → ndarray
Retrieve the column coefficients (gamma).
Column weights represent how each feature is composed as a mixture of column archetypes. These weights provide unique insights into:
Which feature patterns are expressed in each original feature
How features group together based on shared archetype influence
Feature importance through the lens of archetypal patterns
Potential redundancies in the feature space
This feature-space perspective is a distinguishing advantage of biarchetypal analysis compared to standard archetypal methods.
- Returns:
Column weight matrix (n_col_archetypes, n_features)
- get_row_archetypes() → ndarray
Retrieve the row archetypes.
Row archetypes represent extreme patterns in observation space, describing distinctive types of data points. These archetypes are essential for understanding the primary modes of variation among observations and provide the foundation for interpreting data point weights.
In the biarchetypal model, row archetypes are projections of the data matrix via the beta coefficients (β·X).
- Returns:
Row archetypes matrix (n_row_archetypes, n_features)
- get_row_weights() → ndarray
Retrieve the row coefficients (alpha).
Row weights represent how each data point is composed as a mixture of row archetypes. These weights are essential for:
Understanding which archetype patterns dominate each observation
Clustering similar observations based on their archetype compositions
Detecting anomalies as points with unusual archetype weights
Creating reduced-dimension visualizations based on archetype space
The weights are constrained to be non-negative and sum to 1 (simplex constraint), making them directly interpretable as proportions.
- Returns:
Row weight matrix (n_samples, n_row_archetypes)
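The sketch below illustrates the first and third uses listed above. It reads off the dominant archetype of each observation, and it scores how mixed each observation is using the Shannon entropy of its weights. The alpha values are made up for the example. Using entropy as a mixing score is one possible choice, not something the library is documented to do.

```python
import numpy as np

# Illustrative row weights alpha of shape (n_samples, n_row_archetypes); values are made up.
alpha = np.array([
    [0.90, 0.05, 0.05],
    [0.10, 0.80, 0.10],
    [0.34, 0.33, 0.33],
])

# Simplex constraint: non-negative, and each row sums to 1.
assert np.all(alpha >= 0) and np.allclose(alpha.sum(axis=1), 1.0)

dominant = alpha.argmax(axis=1)  # index of the archetype that dominates each observation
# Shannon entropy per row: low = close to a single archetype, high = evenly mixed.
entropy = -(alpha * np.log(np.clip(alpha, 1e-12, None))).sum(axis=1)
```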
- loss_function(params: dict[str, Array], X: Array) → Array
Calculate the composite reconstruction loss for biarchetypal factorization.
This core objective function balances reconstruction quality with sparsity promotion to ensure interpretable representations. Unlike standard AA, the biarchetypal loss operates on the four-factor decomposition X ≈ α·β·X·θ·γ. This requires careful numerical handling to prevent instability during optimization.
The loss promotes three key properties:
1. Accurate data reconstruction through the biarchetypal representation
2. Sparse coefficients for interpretable patterns
3. Numerical stability through explicit type control
- Parameters:
params – Dictionary containing the four model matrices:
- alpha: Row coefficients (n_samples, n_row_archetypes)
- beta: Row archetype coefficients (n_row_archetypes, n_samples), which define the row archetypes as β·X
- theta: Column archetype coefficients (n_features, n_col_archetypes)
- gamma: Column coefficients (n_col_archetypes, n_features)
X – Data matrix (n_samples, n_features)
- Returns:
Combined loss value incorporating reconstruction and regularization terms
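The reconstruction part of this objective can be written down from the shapes listed above. The sketch assumes that the four matrices combine as X ≈ α·β·X·θ·γ, and it computes only the squared Frobenius error. The regularization terms are left out because their exact form is not documented here. The function name is hypothetical.

```python
import numpy as np

def biarchetypal_reconstruction_loss(params, X):
    """Squared Frobenius error of the assumed factorization X ~ alpha @ beta @ X @ theta @ gamma.

    Regularization terms are omitted because their exact form is not documented.
    """
    alpha, beta = params["alpha"], params["beta"]    # (n, k_row), (k_row, n)
    theta, gamma = params["theta"], params["gamma"]  # (m, k_col), (k_col, m)
    biarchetypes = beta @ X @ theta                  # (k_row, k_col)
    X_hat = alpha @ biarchetypes @ gamma             # (n, m) reconstruction
    return float(np.sum((X - X_hat) ** 2))

rng = np.random.default_rng(0)
X = rng.normal(size=(6, 4))
identity = {"alpha": np.eye(6), "beta": np.eye(6), "theta": np.eye(4), "gamma": np.eye(4)}
loss_identity = biarchetypal_reconstruction_loss(identity, X)  # identity factors reproduce X
```

With identity factors, the product collapses back to X, so the reconstruction term vanishes. Any real fit trades this term off against the regularizers.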
- project_col_archetypes(archetypes: Array, X: Array) → Array
Project column archetypes to the boundary of the feature space.
This critical counterpart to row archetype projection ensures column archetypes represent distinct feature patterns. While conceptually similar to row projection, this operation works in the transposed space, treating features as observations and finding extremes among them.
Without this specialized projection, the feature archetypes would not capture meaningful feature combinations, undermining the dual-perspective advantage of biarchetypal analysis.
The implementation:
1. Transposes the problem to work in feature space
2. Identifies feature combinations that represent extremes
3. Creates boundary points through weighted feature combinations
4. Maintains numerical stability throughout
- Parameters:
archetypes – Column archetype matrix (n_features, n_col_archetypes)
X – Data matrix (n_samples, n_features)
- Returns:
Projected column archetypes positioned at the boundaries of feature space, representing distinct feature patterns
- project_col_coefficients(coefficients: Array) → Array
Project column coefficients to satisfy simplex constraints.
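This projection enforces valid convex combinations in the feature space. It differs from row coefficient projection in the axis it normalizes. The coefficient matrix has shape (n_col_archetypes, n_features), so each column (one per feature) must be non-negative and sum to 1 over the column archetypes. The normalization therefore runs along axis 0, rather than along axis 1 as for the row coefficients. This ensures each feature is properly represented by column archetypes.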
This axis-specific projection is a key distinction between standard AA and biarchetypal analysis, enabling the dual-directional nature of the model.
- Parameters:
coefficients – Column coefficient matrix (n_col_archetypes, n_features)
- Returns:
Projected coefficients with each feature’s weights summing to 1, maintaining valid convex combinations in feature space
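One exact way to enforce the per-feature constraint is the sort-based Euclidean projection onto the probability simplex (Duchi et al., 2008), applied column by column. The sketch below uses that technique as an illustration. The library may use a different projection, such as clipping and renormalizing. The function name is hypothetical.

```python
import numpy as np

def project_columns_to_simplex(C):
    """Euclidean projection of each column of C onto the probability simplex.

    Sort-based algorithm (Duchi et al., 2008). Columns are features, so the
    result satisfies result.sum(axis=0) == 1 and result >= 0.
    """
    k, m = C.shape
    U = -np.sort(-C, axis=0)                     # sort each column in descending order
    css = np.cumsum(U, axis=0) - 1.0             # cumulative sums minus the target total 1
    idx = np.arange(1, k + 1)[:, None]           # 1-based positions, broadcast over columns
    cond = U - css / idx > 0                     # True on the leading "active" entries
    rho = k - np.argmax(cond[::-1], axis=0)      # last True position (1-based) per column
    tau = css[rho - 1, np.arange(m)] / rho       # per-column shift
    return np.maximum(C - tau, 0.0)

gamma = np.array([[0.5, 2.0, -1.0],
                  [0.5, 0.0,  3.0],
                  [0.0, 1.0,  0.0]])             # (n_col_archetypes=3, n_features=3)
gamma_proj = project_columns_to_simplex(gamma)
```

The first column is already on the simplex and is left unchanged. The other two columns are moved to their nearest points on the simplex.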
- project_row_archetypes(archetypes: Array, X: Array) → Array
Project row archetypes to the convex hull boundary of data points.
This critical operation ensures row archetypes remain at meaningful extremes of the observation space, where they represent distinct, interpretable patterns. Without this projection, archetypes would tend to collapse toward the data centroid during optimization, losing their representative power.
The implementation uses an adaptive multi-point boundary approximation that:
1. Identifies extreme directions from the data centroid
2. Selects multiple boundary points along each direction
3. Creates weighted combinations that maximize distinctiveness
4. Maintains numeric stability throughout the process
- Parameters:
archetypes – Row archetype matrix (n_row_archetypes, n_samples)
X – Data matrix (n_samples, n_features)
- Returns:
Projected row archetypes positioned at meaningful boundaries of the data’s convex hull
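The four steps above can be approximated in a few lines. The sketch below is a hypothetical NumPy version, not the library's implementation. It treats each archetype as a row of convex coefficients over observations, matching the (n_row_archetypes, n_samples) shape listed above. For each archetype, it finds the observations lying furthest from the centroid along the archetype's direction. It then blends the current coefficients toward a softmax-weighted mix of those points. The function name and the top_k and strength parameters are assumptions.

```python
import numpy as np

def project_to_boundary(coeffs, X, top_k=3, strength=0.5):
    """Hypothetical multi-point boundary approximation.

    coeffs: (n_archetypes, n_samples) convex weights over observations.
    Returns coefficients of the same shape whose rows remain on the simplex.
    """
    centroid = X.mean(axis=0)
    Xc = X - centroid                                # centered observations
    archetypes = coeffs @ X                          # current archetype positions
    new = np.empty_like(coeffs, dtype=float)
    for i, a in enumerate(archetypes):
        d = a - centroid
        d = d / (np.linalg.norm(d) + 1e-12)          # unit direction from the centroid
        scores = Xc @ d                              # how far each point lies along d
        idx = np.argsort(scores)[-top_k:]            # the top_k most extreme observations
        w = np.exp(scores[idx] - scores[idx].max())  # numerically stable softmax weights
        boundary = np.zeros(X.shape[0])
        boundary[idx] = w / w.sum()
        # Convex blend of two simplex rows stays on the simplex.
        new[i] = (1 - strength) * coeffs[i] + strength * boundary
    return new

rng = np.random.default_rng(1)
X = rng.normal(size=(40, 5))
beta = rng.random((3, 40))
beta /= beta.sum(axis=1, keepdims=True)
beta_proj = project_to_boundary(beta, X)
```

Under the same assumptions, the column counterpart described for project_col_archetypes amounts to project_to_boundary(theta.T, X.T).T. That call treats features as observations.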
- project_row_coefficients(coefficients: Array) → Array
Project row coefficients to satisfy simplex constraints.
This projection is essential for maintaining valid convex combinations in the observation space. The simplex constraint (non-negative weights summing to 1) ensures that each data point is represented as a proper weighted combination of row archetypes. Without this constraint, the model would lose its interpretability and might generate unrealistic representations.
The implementation includes numerical safeguards to prevent division by zero and ensure stable optimization even with extreme weight values.
- Parameters:
coefficients – Row coefficient matrix (n_samples, n_row_archetypes)
- Returns:
Projected coefficients satisfying simplex constraints (non-negative, sum to 1)
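The safeguards mentioned above suggest a clip-and-renormalize scheme. The sketch below shows that heuristic for row coefficients, normalizing along axis 1. It adds an epsilon guard and falls back to uniform weights for rows that are entirely non-positive. The fallback, the eps value, and the function name are assumptions. Clip-and-renormalize is cheaper than the exact Euclidean projection onto the simplex, but it does not give the same result.

```python
import numpy as np

def normalize_rows_to_simplex(A, eps=1e-10):
    """Clip negatives to 0, then rescale each row to sum to 1.

    Rows whose sum after clipping is <= eps fall back to uniform weights,
    which avoids division by zero.
    """
    A = np.maximum(A, 0.0)
    sums = A.sum(axis=1, keepdims=True)
    uniform = np.full_like(A, 1.0 / A.shape[1])
    return np.where(sums > eps, A / np.maximum(sums, eps), uniform)

alpha_raw = np.array([[0.2, -0.1, 0.6],
                      [-1.0, -2.0, -0.5],
                      [1.0, 1.0, 2.0]])
alpha = normalize_rows_to_simplex(alpha_raw)
```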
- reconstruct(X: ndarray | None = None) → ndarray
Reconstruct data from biarchetypal representation.
This method is the approximate inverse of transform(): it maps biarchetypal weights back to data points in the original feature space. This serves several purposes:
Validation of model quality through reconstruction error assessment
Interpretation of what specific archetypes represent in data space
Generation of synthetic data by manipulating archetype weights
Noise reduction by reconstructing data through dominant archetypes
The method handles both original training data and new data points.
- Parameters:
X – Optional data matrix to reconstruct. If None, uses the training data
- Returns:
Reconstructed data matrix in the original feature space
- set_fit_request(*, normalize: bool | None | str = '$UNCHANGED$') → BiarchetypalAnalysis
Request metadata passed to the fit method.
Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to fit.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.
- transform(X: ndarray, y: ndarray | None = None, **kwargs) → Any
Transform new data into dual-directional archetype space.
This method computes optimal weights to represent new data in terms of discovered archetypes, enabling consistent interpretation of new observations within the established biarchetypal framework.
Unlike conventional AA, this transform operates in both row and column spaces, so new observations are described in the same dual-perspective terms as the training data. The biarchetypes learned during fit are reused rather than re-estimated.
- Parameters:
X – New data matrix (n_samples, n_features) to transform
y – Ignored. Present for scikit-learn API compatibility
**kwargs – Additional parameters for customizing transformation
- Returns:
Tuple of (row_weights, col_weights) representing the data in the biarchetypal space
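On the row side of this transform, new observations need simplex-constrained weights over a fixed basis. The sketch below assumes that basis is the product of the biarchetypes and gamma, with shape (n_row_archetypes, n_features), as in the reconstruction X ≈ α·(β·X·θ)·γ. It solves for the weights with exponentiated gradient descent, a multiplicative update that keeps each row non-negative and summing to 1. The function name, step size, and iteration count are hypothetical. The library's solver and its handling of the column weights are not documented here.

```python
import numpy as np

def infer_row_weights(X_new, basis, n_iter=500, lr=0.5):
    """Hypothetical solver: simplex-constrained W minimizing ||X_new - W @ basis||_F^2.

    basis: (n_row_archetypes, n_features), assumed here to be biarchetypes @ gamma.
    Exponentiated gradient keeps every row of W non-negative and summing to 1.
    """
    k = basis.shape[0]
    W = np.full((X_new.shape[0], k), 1.0 / k)        # start at the centre of the simplex
    scale = np.linalg.norm(basis, 2) ** 2 + 1e-12    # largest singular value, squared
    for _ in range(n_iter):
        grad = (W @ basis - X_new) @ basis.T / scale  # rescaled gradient of 0.5 * squared error
        W = W * np.exp(-lr * grad)                    # multiplicative (mirror-descent) update
        W /= W.sum(axis=1, keepdims=True)             # renormalize onto the simplex
    return W

basis = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])  # 3 archetypes in 2-D (toy example)
true_W = np.array([[0.2, 0.5, 0.3], [0.6, 0.2, 0.2]])
X_new = true_W @ basis
W = infer_row_weights(X_new, basis)
```

Multiplying W by basis gives the reconstruction of the new data that reconstruct() describes, under the same assumptions.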
archetypax.models.sparse_archetypes module
Sparse Archetypal Analysis: Interpretable pattern discovery with sparsity constraints.
This module extends archetypal analysis with sparsity-promoting regularization, enabling more interpretable and focused archetype discovery. By encouraging archetypes to utilize only essential features, this approach addresses a key limitation of standard archetypal analysis: the tendency to produce dense, difficult-to-interpret archetypes in high-dimensional spaces.
- class archetypax.models.sparse_archetypes.SparseArchetypalAnalysis(n_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, lambda_reg: float = 0.01, lambda_sparsity: float = 0.1, sparsity_method: str = 'l1', normalize: bool = False, projection_method: str = 'cbap', projection_alpha: float = 0.1, archetype_init_method: str = 'directional', min_volume_factor: float = 0.001, **kwargs)
Bases:
ImprovedArchetypalAnalysis
Archetypal Analysis with sparsity constraints for enhanced interpretability.
This implementation addresses a fundamental challenge in standard archetypal analysis: dense archetypes that utilize many features are often difficult to interpret, particularly in high-dimensional datasets where most features may be irrelevant to specific patterns.
By incorporating sparsity constraints, this approach offers several key advantages:
More interpretable archetypes that focus on truly relevant features
Automatic feature selection within the archetypal framework
Improved robustness to noise and irrelevant dimensions
Better generalization by preventing overfitting to spurious correlations
Computationally efficient representations, especially for high-dimensional data
Multiple sparsity-promoting methods are supported, enabling adaptation to different data characteristics and interpretability requirements.
- __init__(n_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, lambda_reg: float = 0.01, lambda_sparsity: float = 0.1, sparsity_method: str = 'l1', normalize: bool = False, projection_method: str = 'cbap', projection_alpha: float = 0.1, archetype_init_method: str = 'directional', min_volume_factor: float = 0.001, **kwargs)
Initialize the Sparse Archetypal Analysis model.
- Parameters:
n_archetypes – Number of archetypes to discover - controls the model’s expressiveness and granularity of pattern discovery
max_iter – Maximum optimization iterations - higher values enable better convergence at computational cost
tol – Convergence tolerance for early stopping - smaller values yield more precise solutions but require more iterations
random_seed – Random seed for reproducibility across runs
learning_rate – Gradient descent step size - larger values converge faster but risk unstable updates
lambda_reg – Weight regularization strength - controls weight sparsity for better interpretability
lambda_sparsity – Archetype sparsity strength - higher values produce more focused archetypes using fewer features
sparsity_method – Technique for promoting archetype sparsity:
- “l1”: L1 regularization (fastest, robust, tends to zero out features)
- “l0_approx”: Approximated L0 regularization (more aggressive sparsity)
- “feature_selection”: Entropy-based selection (focuses on key features)
normalize – Whether to normalize features - essential for data with different scales
projection_method – Method for projecting archetypes onto the convex hull:
- “cbap”: Convex boundary approximation (default, most stable)
- “convex_hull”: Exact convex hull vertices (more accurate)
- “knn”: K-nearest neighbors approximation (faster for large datasets)
projection_alpha – Strength of boundary projection - higher values push archetypes more aggressively toward extremes
archetype_init_method – Initialization strategy for archetypes:
- “directional”: Directions from data centroid (robust default)
- “qhull”/“convex_hull”: Convex hull vertices (geometry-aware)
- “kmeans”/“kmeans++”: K-means++ initialization (density-aware)
min_volume_factor – Minimum volume requirement for archetype simplex - prevents degenerate solutions with collapsed archetypes
**kwargs – Additional parameters including:
- early_stopping_patience: Iterations with no improvement before stopping
- verbose_level/logger_level: Controls logging detail
- diversify_archetypes(archetypes: Array, X: Array) → Array
Prevent degenerate solutions by ensuring sufficient archetype diversity.
This critical post-processing step addresses a fundamental challenge in archetypal analysis: optimization can sometimes converge to solutions where multiple archetypes collapse to similar positions, particularly when sparsity is enforced.
Such degenerate solutions drastically reduce model expressiveness and interpretability. This method actively counteracts this tendency by:
Detecting potential degeneracy through simplex volume measurement
Systematically pushing archetypes away from each other when needed
Ensuring archetypes remain valid (within the convex hull)
Verifying improvement through before/after volume comparison
This operation is performed outside the JAX-compiled update steps since it involves non-differentiable operations and contingent logic.
- Parameters:
archetypes – Current archetypes matrix to diversify
X – Data matrix defining the convex hull boundary
- Returns:
Diversified archetypes with improved distribution and volume
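The volume measurement mentioned above can be computed for any number of archetypes using the Gram determinant. Let E be the edge matrix whose rows are a_i - a_0, for the k archetypes a_0, ..., a_(k-1). The volume is then sqrt(det(E·Eᵀ)) / (k - 1)!. This formula still works when the simplex lives in a higher-dimensional feature space. The sketch below is illustrative and not taken from the library; the function name is hypothetical.

```python
import math

import numpy as np

def simplex_volume(archetypes):
    """Volume of the simplex spanned by the archetype rows (Gram-determinant form)."""
    k = archetypes.shape[0]
    if k < 2:
        return 0.0
    E = archetypes[1:] - archetypes[0]     # (k - 1, n_features) edge vectors
    gram_det = np.linalg.det(E @ E.T)      # squared (k - 1)-volume of the parallelotope
    return math.sqrt(max(gram_det, 0.0)) / math.factorial(k - 1)

triangle = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
collinear = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])  # collapsed archetypes
vol_triangle = simplex_volume(triangle)
vol_collinear = simplex_volume(collinear)
```

A degeneracy check could compare this value against a threshold, such as one derived from min_volume_factor. How the library scales that threshold is not documented here.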
- fit(X: ndarray, normalize: bool = False, **kwargs) → SparseArchetypalAnalysis
Fit the model to discover sparse, interpretable archetypes.
This method orchestrates the complete sparse archetypal analysis process, building on the standard archetypal optimization while incorporating critical extensions for sparsity and stability:
Leverages the parent class for core optimization
Applies the selected sparsity-promoting method during optimization
Performs post-processing to ensure archetype diversity
Validates sparsity and geometric properties of the solution
The result is a set of archetypes that balance reconstruction fidelity, interpretability, and geometric meaningfulness.
- Parameters:
X – Data matrix (n_samples, n_features)
normalize – Whether to normalize features before fitting - essential for data with different scales
**kwargs – Additional parameters for customizing the fitting process
- Returns:
Self - fitted model instance with discovered sparse archetypes
- fit_transform(X: ndarray, y: ndarray | None = None, normalize: bool = False, **kwargs) → ndarray
Fit the model and immediately transform the input data.
This convenience method combines model fitting and data transformation in a single operation, which offers two key advantages:
Computational efficiency by avoiding redundant calculations
Simplified workflow for immediate archetype-based representation
This method is particularly useful in analysis pipelines or when integrating with scikit-learn compatible frameworks that expect this pattern.
- Parameters:
X – Data matrix to fit and transform (n_samples, n_features)
y – Ignored. Present for scikit-learn API compatibility
normalize – Whether to normalize features before fitting
**kwargs – Additional parameters passed to fit()
- Returns:
Weight matrix representing each sample as a combination of the discovered sparse archetypes (n_samples, n_archetypes)
- get_archetype_sparsity() → ndarray
Calculate the effective sparsity of each archetype.
This diagnostic method provides a quantitative measure of how successfully the sparsity constraints have been applied to each archetype. Rather than simply counting zeros (which may be unsuitable for soft-thresholded approaches), it uses the Gini coefficient as a more nuanced sparsity metric.
The Gini coefficient measures the inequality among values, with higher values indicating greater sparsity (few large values, many small values). This provides:
A standardized way to compare archetypes’ feature utilization
A continuous measure that works with both hard and soft sparsity
A basis for identifying archetypes that may need further refinement
A metric for evaluating different sparsity methods
- Returns:
Array containing sparsity scores for each archetype (higher values indicate more focused archetypes using fewer features)
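The Gini coefficient has a closed form on sorted magnitudes. Sort the n absolute values of one archetype in ascending order as c_(1) ≤ ... ≤ c_(n). Then G = Σ_k (2k - n - 1)·c_(k) / (n·Σ_k c_(k)). G is 0 when all features carry equal weight, and (n - 1)/n when a single feature carries all of it. The sketch below applies this formula per archetype row. Whether the library uses absolute values and exactly this normalization is an assumption, and the function name is hypothetical.

```python
import numpy as np

def gini_sparsity(archetypes):
    """Gini coefficient of |values| per archetype row (0 = uniform, near 1 = sparse)."""
    C = np.sort(np.abs(archetypes), axis=1)  # magnitudes, ascending per archetype
    n = C.shape[1]
    k = np.arange(1, n + 1)                  # 1-based ranks
    totals = C.sum(axis=1)
    return (C * (2 * k - n - 1)).sum(axis=1) / (n * np.maximum(totals, 1e-12))

A = np.array([[1.0, 1.0, 1.0, 1.0],   # uses every feature equally
              [0.0, 0.0, 0.0, 5.0],   # uses a single feature
              [0.0, 1.0, 2.0, 3.0]])  # in between
scores = gini_sparsity(A)
```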
- loss_function(archetypes: Array, weights: Array, X: Array) → Array
Calculate the composite loss function incorporating sparsity constraints.
This enhanced objective function extends the standard archetypal loss with multiple sparsity-promoting regularization terms, balancing several competing objectives:
Reconstruction accuracy: Ensuring archetypes accurately represent the data
Archetype sparsity: Promoting focused archetypes that use fewer features
Weight interpretability: Encouraging sparse, distinctive weight patterns
Boundary alignment: Maintaining archetypes at meaningful data extremes
Archetype diversity: Preventing redundant or overlapping archetypes
The balance between these terms is critical - too much emphasis on sparsity may sacrifice reconstruction quality, while too little won’t yield the interpretability benefits.
- Parameters:
archetypes – Archetype matrix (n_archetypes, n_features)
weights – Weight matrix (n_samples, n_archetypes)
X – Data matrix (n_samples, n_features)
- Returns:
Combined loss incorporating reconstruction error and multiple regularization terms
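The sketch below is a stripped-down version of this objective that keeps only the reconstruction term and the “l1” sparsity option. The mean-based scaling and the function name are assumptions. The boundary, diversity, and weight terms listed above are omitted because their exact form is not documented here.

```python
import numpy as np

def sparse_aa_loss(archetypes, weights, X, lambda_sparsity=0.1):
    """Mean squared reconstruction error plus an L1 penalty on the archetypes."""
    recon = np.mean((X - weights @ archetypes) ** 2)  # reconstruction accuracy
    sparsity = np.mean(np.abs(archetypes))            # "l1" archetype sparsity
    return recon + lambda_sparsity * sparsity

archetypes = np.array([[1.0, 0.0], [0.0, 2.0]])       # (n_archetypes, n_features)
weights = np.array([[0.5, 0.5], [1.0, 0.0]])          # (n_samples, n_archetypes)
X = weights @ archetypes                              # perfectly reconstructable toy data
loss = sparse_aa_loss(archetypes, weights, X, lambda_sparsity=0.1)  # only the L1 term remains
```

Raising lambda_sparsity makes the penalty term dominate, which illustrates the trade-off described above between sparsity and reconstruction quality.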
- set_fit_request(*, normalize: bool | None | str = '$UNCHANGED$') → SparseArchetypalAnalysis
Request metadata passed to the fit method.
Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see User Guide on how the routing mechanism works.
The options for each parameter are:
True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.
False: metadata is not requested and the meta-estimator will not pass it to fit.
None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
Note
This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.
- update_archetypes(archetypes: Array, weights: Array, X: Array) → Array
Update archetypes with sparsity promotion and degeneracy prevention.
This enhanced update method extends the standard approach with critical improvements for sparse archetypal analysis:
Sparsity promotion through the selected method (“l1”, “l0_approx”, or “feature_selection”)
Variance-aware noise injection to prevent dimensional collapse
Constraint enforcement to maintain valid convex combinations
Boundary projection to ensure archetypes remain at meaningful extremes
These extensions are essential for learning truly interpretable archetypes while maintaining numerical stability and preventing convergence to degenerate solutions.
- Parameters:
archetypes – Current archetype matrix (n_archetypes, n_features)
weights – Weight matrix (n_samples, n_archetypes)
X – Data matrix (n_samples, n_features)
- Returns:
Updated archetypes incorporating sparsity and diversity constraints
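For the “l1” option, proximal gradient descent is a standard way to combine a gradient step with sparsity: the L1 penalty is applied through soft-thresholding. The sketch below shows one such step on 0.5·‖X - W·A‖²_F + lambda_sparsity·‖A‖₁. It is a stand-in for the sparsity-promotion part only. It omits the noise injection, constraint enforcement, and boundary projection that the method also performs. Both function names are hypothetical.

```python
import numpy as np

def soft_threshold(A, thresh):
    """Proximal operator of thresh * ||A||_1: shrink every entry toward zero."""
    return np.sign(A) * np.maximum(np.abs(A) - thresh, 0.0)

def sparse_archetype_step(archetypes, weights, X, lr=0.01, lambda_sparsity=0.1):
    """One proximal-gradient step on 0.5 * ||X - W A||_F^2 + lambda_sparsity * ||A||_1."""
    grad = weights.T @ (weights @ archetypes - X)  # gradient of the smooth term w.r.t. A
    return soft_threshold(archetypes - lr * grad, lr * lambda_sparsity)

A = np.array([[1.0, 0.05], [0.0, 2.0]])
shrunk = soft_threshold(A, 0.1)  # entries smaller than 0.1 in magnitude become exactly 0
```

Unlike a plain L1 subgradient step, soft-thresholding sets small entries to exactly zero. This is why it is a common choice for producing archetypes that use fewer features.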
Module contents
Core model implementations for Archetypal Analysis.
This module provides specialized implementations of Archetypal Analysis algorithms, each addressing specific analytical challenges and use cases. Archetypal Analysis discovers extreme patterns in data that serve as the building blocks for representing all observations as convex combinations.
- Available Models:
- ArchetypalAnalysis:
Foundational implementation suitable for low-dimensional datasets or initial exploration when computational efficiency matters
- ImprovedArchetypalAnalysis:
Enhanced version with advanced initialization strategies, robust optimization, and boundary projection techniques - recommended for most applications due to superior stability and convergence properties
- SparseArchetypalAnalysis:
Implementation enforcing feature sparsity in archetypes - essential for high-dimensional data where interpretability is a priority and feature selection is desirable
- BiarchetypalAnalysis:
Dual-direction analysis revealing patterns in both observation and feature spaces simultaneously - ideal for datasets where understanding relationships between features is as important as clustering observations
- Basic Usage:
from archetypax.models import ArchetypalAnalysis
model = ArchetypalAnalysis(n_archetypes=5)
model.fit(data)  # data: array of shape (n_samples, n_features)
archetypes = model.get_archetypes()