archetypax.models package

Submodules

archetypax.models.archetypes module

Improved Archetypal Analysis using JAX.

This module extends the base ArchetypalAnalysis class with enhanced optimization strategies and boundary projection techniques. The ImprovedArchetypalAnalysis class provides a more robust and versatile implementation of Archetypal Analysis (AA) using JAX for GPU acceleration.

The improvements focus on:

  • Multiple initialization strategies (directional, convex hull, kmeans++)

  • Advanced optimization with hybrid gradient and direct update methods

  • Adaptive boundary projection techniques

  • Better convergence stability through regularization

Key advantages over the base implementation:

  • More stable convergence across diverse datasets

  • Higher quality solutions with improved boundary placement

  • Richer configuration options for domain-specific tuning

  • Enhanced computational efficiency for large-scale applications

Example usage:

```python
import numpy as np
from archetypax.models import ImprovedArchetypalAnalysis

# Placeholder data: 200 samples with 3 features
X = np.random.default_rng(0).normal(size=(200, 3))

# Initialize model
model = ImprovedArchetypalAnalysis(
    n_archetypes=5,
    normalize=True,
    archetype_init_method="directional",
    projection_method="cbap",
)

# Fit model and transform data
weights = model.fit_transform(X)

# Extract discovered archetypes
archetypes = model.archetypes
```

class archetypax.models.archetypes.ImprovedArchetypalAnalysis(n_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, lambda_reg: float = 0.01, normalize: bool = False, projection_method: str = 'cbap', projection_alpha: float = 0.1, archetype_init_method: str = 'directional', **kwargs)

Bases: ArchetypalAnalysis

Improved Archetypal Analysis model using JAX.

__init__(n_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, lambda_reg: float = 0.01, normalize: bool = False, projection_method: str = 'cbap', projection_alpha: float = 0.1, archetype_init_method: str = 'directional', **kwargs)

Initialize an enhanced archetypal analysis model with robust optimization.

This improved implementation addresses key limitations of standard archetypal analysis through advanced initialization strategies, robust gradient-based optimization, and adaptive boundary projection techniques. These enhancements are intended to improve convergence stability, solution quality, and computational efficiency across diverse datasets.

Parameters:
  • n_archetypes – Number of archetypes to discover - controls the model’s expressiveness and dimensionality reduction ratio. Higher values capture more nuanced patterns at the cost of interpretability and potential overfitting.

  • max_iter – Maximum optimization iterations - higher values give the optimizer more opportunity to converge, at the cost of longer runtime. The default (500) balances solution quality with reasonable runtime for most datasets.

  • tol – Convergence tolerance - smaller values yield more precise solutions but require more iterations. The default (1e-6) is suitable for most applications, while scientific applications may require smaller values.

  • random_seed – Random seed for reproducibility - ensures consistent results across runs with the same data and parameters.

  • learning_rate – Gradient descent step size - critical parameter balancing convergence speed with stability. Too high risks overshooting minima, while too low causes slow convergence.

  • lambda_reg – Regularization strength for weights - controls the balance between reconstruction accuracy and weight sparsity. Higher values promote more discrete archetype assignments.

  • normalize – Whether to normalize features - essential when features have different scales to prevent dominance by high-magnitude features. Should be True for most real-world datasets.

  • projection_method – Strategy for projecting archetypes to the boundary:

    • "cbap" (default): convex boundary approximation projection - a balanced approach suitable for most datasets

    • "convex_hull": uses exact convex hull vertices - more precise but computationally intensive in high dimensions

    • "knn": k-nearest neighbors approximation - faster for large datasets

  • projection_alpha – Projection strength parameter (0-1) - controls how aggressively archetypes are pushed toward the boundary. Higher values emphasize extremeness over reconstruction.

  • archetype_init_method – Initialization strategy for archetypes:

    • "directional" (default): directions from the centroid - a robust general-purpose approach that balances diversity with boundary alignment

    • "qhull"/"convex_hull": exact convex hull vertices - ideal when geometric extremes are well defined

    • "kmeans"/"kmeans++": k-means++ initialization - beneficial when density-based initialization aligns with domain expectations

  • **kwargs – Additional parameters:

    • early_stopping_patience: iterations without improvement before stopping

    • verbose_level: controls logging detail (0-4):

      • 0: Critical only

      • 1: Error level

      • 2: Warning level

      • 3: Info level (recommended for monitoring)

      • 4: Debug level (verbose training details)

    • logger_level: alternative to verbose_level with reversed mapping

directional_init(X_jax: Array, n_samples: int, n_features: int) → tuple[Array, Array]

Initialize archetypes along directions defined by points that are evenly distributed on a sphere.

Parameters:
  • X_jax – Data matrix of shape (n_samples, n_features)

  • n_samples – Number of samples

  • n_features – Number of features

Returns:

Archetypes and archetype indices
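The directional idea can be illustrated with a short NumPy sketch. It is not archetypax's implementation: it assumes each archetype is the data point with the largest projection along a unit direction from the data centroid, and it substitutes random Gaussian directions (uniformly distributed on the sphere, but not evenly spaced) for the evenly distributed sphere points described above. The name `directional_init_sketch` is hypothetical.

```python
import numpy as np

def directional_init_sketch(X, n_archetypes, seed=0):
    """Pick the most extreme sample along each random unit direction.

    Hypothetical helper: random Gaussian directions stand in for the evenly
    distributed sphere points used by archetypax.
    """
    rng = np.random.default_rng(seed)
    centered = X - X.mean(axis=0)
    # Normalized Gaussian vectors are uniformly distributed on the unit sphere
    directions = rng.normal(size=(n_archetypes, X.shape[1]))
    directions /= np.linalg.norm(directions, axis=1, keepdims=True)
    # Projection of every centered sample onto every direction
    scores = centered @ directions.T          # (n_samples, n_archetypes)
    # The most extreme sample along each direction becomes an archetype
    indices = np.argmax(scores, axis=0)       # (n_archetypes,)
    return X[indices], indices
```

Two directions can select the same sample; a full implementation would need some way of handling such duplicates.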

fit(X: ndarray, normalize: bool = False, **kwargs) → ImprovedArchetypalAnalysis

Discover optimal archetypes through advanced multi-strategy optimization.

This core method identifies the extreme patterns that define the convex hull of the data and serve as the building blocks for representing all observations. The implementation features several critical enhancements:

  1. Intelligent initialization strategies that target promising positions

  2. Hybrid optimization combining gradient-based and direct algebraic updates

  3. Adaptive boundary projection to ensure archetypes represent true extremes

  4. Improved numerical stability through strategic regularization

  5. Early stopping logic to prevent overfitting and wasted computation

These techniques collectively address three fundamental challenges of archetypal analysis: sensitivity to initialization, convergence to suboptimal solutions, and computational efficiency.

Parameters:
  • X – Data matrix to analyze (n_samples, n_features)

  • normalize – Whether to normalize features before fitting - essential for data with features of different scales or magnitudes

  • **kwargs – Additional optimization parameters:

    • early_stopping_patience: iterations without improvement before stopping (higher values give the optimizer more chances to improve, at additional computational cost)

    • additional parameters specific to the initialization method

Returns:

Self - fitted model instance with discovered archetypes

fit_transform(X: ndarray, y: ndarray | None = None, normalize: bool = False, **kwargs) → ndarray

Fit the model and immediately transform the input data.

This convenience method combines model fitting and data transformation in a single operation, which offers two key advantages:

  1. Computational efficiency by avoiding redundant calculations

  2. Simplified workflow for immediate archetypal representation

This method is particularly valuable in analysis pipelines or when integrating with scikit-learn compatible frameworks that expect this pattern. It ensures that the transformation is performed with the same preprocessing settings used during fitting.

Parameters:
  • X – Data matrix to fit and transform (n_samples, n_features)

  • y – Ignored, present for scikit-learn API compatibility

  • normalize – Whether to normalize features before fitting - essential for data with different scales or magnitudes

  • **kwargs – Additional parameters passed to both fit() and transform(), including optimization settings and convergence criteria

Returns:

Weight matrix representing each sample as a combination of the discovered archetypes (n_samples, n_archetypes)

kmeans_pp_init(X_jax: Array, n_samples: int, n_features: int) → tuple[Array, Array]

More efficient k-means++ style initialization using JAX.

Parameters:
  • X_jax – Data matrix of shape (n_samples, n_features)

  • n_samples – Number of samples

  • n_features – Number of features

Returns:

Archetypes and archetype indices
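For reference, standard k-means++ seeding picks the first center uniformly at random and then samples each further center with probability proportional to its squared distance from the nearest center chosen so far. The NumPy sketch below illustrates that technique; it does not reproduce archetypax's JAX code, and `kmeans_pp_sketch` is a hypothetical name.

```python
import numpy as np

def kmeans_pp_sketch(X, n_archetypes, seed=0):
    """Standard k-means++ seeding; returns (centers, indices)."""
    rng = np.random.default_rng(seed)
    n_samples = X.shape[0]
    # First center: a uniformly random sample
    indices = [int(rng.integers(n_samples))]
    # Squared distance from each sample to its nearest chosen center
    d2 = np.sum((X - X[indices[0]]) ** 2, axis=1)
    for _ in range(1, n_archetypes):
        # Sample the next center with probability proportional to d2
        probs = d2 / d2.sum()
        nxt = int(rng.choice(n_samples, p=probs))
        indices.append(nxt)
        # Update nearest-center distances with the new center
        d2 = np.minimum(d2, np.sum((X - X[nxt]) ** 2, axis=1))
    indices = np.array(indices)
    return X[indices], indices
```

Already-chosen samples have zero distance and therefore zero probability, so the returned indices are distinct as long as the data contain no duplicate rows.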

loss_function(archetypes: Array, weights: Array, X: Array) → Array

Composite objective function balancing reconstruction with interpretability.

This carefully designed loss function guides the optimization process by balancing multiple competing objectives essential for archetypal analysis:

  1. Reconstruction fidelity: ensuring archetypes accurately represent the data

  2. Weight interpretability: encouraging sparse, distinctive weight patterns

  3. Boundary alignment: promoting archetypes at meaningful extremal positions

The weighted combination of these terms creates a landscape that guides optimization toward solutions with both mathematical validity (convex hull representation) and practical utility (interpretable patterns).

The relative weighting of these components is critical to achieving the right balance between reconstruction accuracy and archetypal properties.

Parameters:
  • archetypes – Candidate archetype matrix (n_archetypes, n_features)

  • weights – Weight matrix (n_samples, n_archetypes) describing how to represent each sample as a combination of archetypes

  • X – Original data matrix (n_samples, n_features) to reconstruct

Returns:

Scalar loss value combining reconstruction error with regularization terms - lower values indicate better solutions

project_archetypes(archetypes: Array, X: Array) → Array

Strategically position archetypes on the convex hull boundary for optimal representation.

This method is critical for meaningful archetypal analysis as it ensures archetypes remain at the extremes of the data distribution where they best represent distinctive patterns.

Our implementation differs from standard projection methods by:

  1. Projecting along meaningful directions from the data centroid

  2. Identifying precise extreme points rather than using approximate methods

  3. Blending original positions with boundary points for stability

  4. Applying adaptive adjustments based on current position

Parameters:
  • archetypes – Current archetype positions (n_archetypes, n_features)

  • X – Data matrix defining the convex hull (n_samples, n_features)

Returns:

Projected archetypes strategically positioned at or near the convex hull boundary
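A minimal sketch of the project-and-blend idea follows. The assumptions are ours, not the library's: each archetype's direction is measured from the data centroid, the extreme point is the sample with the largest projection along that direction, and `projection_alpha` acts as the blending weight. The helper name `blend_toward_boundary` is hypothetical, and the adaptive adjustments are omitted.

```python
import numpy as np

def blend_toward_boundary(archetypes, X, alpha=0.1):
    """Move each archetype a fraction alpha toward the most extreme sample
    along its direction from the data centroid (hypothetical helper)."""
    centroid = X.mean(axis=0)
    directions = archetypes - centroid
    norms = np.linalg.norm(directions, axis=1, keepdims=True)
    directions = directions / np.maximum(norms, 1e-12)   # avoid division by zero
    # Projection of each centered sample onto each archetype direction
    scores = (X - centroid) @ directions.T               # (n_samples, n_archetypes)
    extremes = X[np.argmax(scores, axis=0)]              # (n_archetypes, n_features)
    # Convex blend of the current position and the boundary point
    return (1.0 - alpha) * archetypes + alpha * extremes
```

For example, with the corners of the unit square plus its center as data, an archetype at (0.6, 0.6) and `alpha=0.5` moves halfway toward the corner (1, 1), ending at (0.8, 0.8).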

project_archetypes_convex_hull(archetypes: Array, X: Array) → Array

Alternative archetype projection that uses convex combinations of extreme points.

This method identifies potential extreme points and creates archetypes as sparse convex combinations of these points, keeping them on or near the boundary.

Technical details:

  • Multi-directional exploration: generates multiple random directions around the main archetype direction, allowing for more diverse extreme point discovery.

  • Sparse convex combinations: creates archetypes as weighted combinations of extreme points found in different directions, with emphasis on the main direction.

  • Boundary positioning: by using convex combinations of extreme points, archetypes are positioned on or near the convex hull boundary rather than in its interior.

This approach offers potentially better exploration of the convex hull boundary at the cost of slightly higher computational complexity.

Parameters:
  • archetypes – Current archetype matrix of shape (n_archetypes, n_features)

  • X – Original data matrix of shape (n_samples, n_features)

Returns:

Projected archetype matrix positioned on or near the convex hull boundary
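The three technical steps above can be illustrated with the following sketch. It is not the archetypax implementation: the number of perturbed directions, the perturbation scale, and the weighting scheme (`main_weight` on the main direction's extreme point, the rest split evenly) are hypothetical choices made for illustration. Because the result is a convex combination of samples, it always lies within the convex hull, and it lies exactly on the boundary only when the selected extreme points share a common face.

```python
import numpy as np

def multi_direction_projection(archetypes, X, n_dirs=5, scale=0.2,
                               main_weight=0.5, seed=0):
    """Replace each archetype with a convex combination of extreme samples
    found along its main direction and along perturbed directions."""
    rng = np.random.default_rng(seed)
    centroid = X.mean(axis=0)
    centered = X - centroid
    projected = []
    for a in archetypes:
        main = a - centroid
        main = main / max(np.linalg.norm(main), 1e-12)
        # Main direction first, followed by random perturbations around it
        perturbed = main + scale * rng.normal(size=(n_dirs, X.shape[1]))
        dirs = np.vstack([main, perturbed])
        dirs /= np.linalg.norm(dirs, axis=1, keepdims=True)
        # Extreme sample along each direction: (n_dirs + 1, n_features)
        extremes = X[np.argmax(centered @ dirs.T, axis=0)]
        # Convex weights: main_weight on the main direction, rest shared evenly
        w = np.full(len(dirs), (1.0 - main_weight) / n_dirs)
        w[0] = main_weight
        projected.append(w @ extremes)
    return np.array(projected)
```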

project_archetypes_knn(archetypes: Array, X: Array) → Array

Original k-NN based archetype projection (kept for comparison).

This method tends to pull archetypes inside the convex hull due to its averaging nature, which is suboptimal for archetypal analysis where archetypes should ideally lie on the convex hull boundary.

Parameters:
  • archetypes – Current archetype matrix

  • X – Original data matrix

Returns:

Projected archetype matrix (typically positioned inside the convex hull)

project_weights(weights: Array) → Array

JIT-compiled projection of weights onto the probability simplex (non-negative entries that sum to 1 in each row).

Parameters:

weights – Weight matrix of shape (n_samples, n_archetypes)

Returns:

Projected weight matrix of shape (n_samples, n_archetypes)

qhull_init(X_jax: Array, n_samples: int, n_features: int) → tuple[Array, Array]

Initialize archetypes using convex hull vertices via QHull algorithm.
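A sketch of convex-hull-based initialization using SciPy's `scipy.spatial.ConvexHull`, which wraps the Qhull library. How archetypax chooses among hull vertices when there are more vertices than archetypes is not documented here; this sketch assumes, hypothetically, that the vertices farthest from the centroid are kept. Qhull becomes impractical as the number of features grows, so this approach suits low-dimensional data.

```python
import numpy as np
from scipy.spatial import ConvexHull

def qhull_init_sketch(X, n_archetypes):
    """Pick archetypes among convex hull vertices (hypothetical selection rule:
    keep the vertices farthest from the data centroid)."""
    # Qhull needs at least n_features + 1 affinely independent points
    hull = ConvexHull(X)
    vertices = hull.vertices                       # indices of hull vertices in X
    dist = np.linalg.norm(X[vertices] - X.mean(axis=0), axis=1)
    # Keep the n_archetypes most distant vertices
    chosen = vertices[np.argsort(dist)[::-1][:n_archetypes]]
    return X[chosen], chosen
```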

set_fit_request(*, normalize: bool | None | str = '$UNCHANGED$') → ImprovedArchetypalAnalysis

Request metadata passed to the fit method.

Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to fit.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version 1.3.

Note

This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.

Parameters:

normalize (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for normalize parameter in fit.

Returns:

self – The updated object.

Return type:

object

transform(X: ndarray, y: ndarray | None = None, **kwargs) → ndarray

Transform data into archetypal weight space with adaptive optimization.

This method computes optimal weights representing each sample as a convex combination of discovered archetypes. The transformation reveals how samples relate to extreme patterns, offering:

  1. Dimensionality reduction while preserving interpretability

  2. Soft clustering based on meaningful archetypes rather than arbitrary centroids

  3. Insights into sample composition and relationship to extreme patterns

  4. A foundation for transfer learning when applying archetypes to new data

Multiple optimization strategies are available, with adaptive selection based on dataset size to balance computational efficiency with solution quality.

Parameters:
  • X – Data matrix to transform (n_samples, n_features)

  • y – Ignored, present for scikit-learn API compatibility

  • **kwargs – Additional parameters:

    • method: optimization approach to use:

      • "lbfgs": best for small datasets (<1000 samples)

      • "adam": balanced option for mid-sized data (default)

      • "sgd": memory-efficient for large datasets

      • "adaptive": automatically selects based on data size

    • max_iter: maximum iterations for weight optimization

    • tol: convergence tolerance (smaller values for more precision)

Returns:

Weight matrix representing each sample as a combination of the discovered archetypes (n_samples, n_archetypes)
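With the archetypes held fixed, each sample's weights solve a least-squares problem over the probability simplex. The sketch below uses exponentiated-gradient (mirror descent) updates, which keep the weights non-negative and summing to one without a separate projection step. This is a swapped-in technique for illustration, not one of the lbfgs, adam, or sgd strategies listed above; `simplex_weights_eg` and its step size are hypothetical.

```python
import numpy as np

def simplex_weights_eg(X, archetypes, n_iter=500, step=0.5):
    """Minimize 0.5 * ||X - W @ archetypes||^2 with each row of W on the simplex."""
    n_samples, n_archetypes = X.shape[0], archetypes.shape[0]
    W = np.full((n_samples, n_archetypes), 1.0 / n_archetypes)  # uniform start
    for _ in range(n_iter):
        residual = W @ archetypes - X          # (n_samples, n_features)
        grad = residual @ archetypes.T         # gradient with respect to W
        W = W * np.exp(-step * grad)           # multiplicative (mirror) update
        W /= W.sum(axis=1, keepdims=True)      # renormalize rows onto the simplex
    return W
```

For data with large feature magnitudes, the step size must shrink accordingly (or the data be normalized first) to keep `np.exp` well behaved.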

update_archetypes(archetypes: Array, weights: Array, X: Array) → Array

Alternative archetype update strategy based on weighted reconstruction.

This approach directly optimizes archetypes using the pseudo-inverse of the weight matrix. For fixed weights, the pseudo-inverse yields the least-squares archetypes in closed form, which is often a more targeted update than gradient descent for this specific subproblem.

Parameters:
  • archetypes – Archetype matrix of shape (n_archetypes, n_features)

  • weights – Weight matrix of shape (n_samples, n_archetypes)

  • X – Data matrix of shape (n_samples, n_features)

Returns:

Updated archetype matrix of shape (n_archetypes, n_features)
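For fixed weights W, the least-squares archetypes are A = W⁺X, where W⁺ is the Moore-Penrose pseudo-inverse. The minimal NumPy sketch below shows this step; the helper name is hypothetical, and it ignores the `archetypes` argument of the real method. This closed-form step does not by itself keep archetypes inside the convex hull, so a projection step would typically follow it.

```python
import numpy as np

def pinv_archetype_update(weights, X):
    """Archetypes minimizing ||X - weights @ A||_F for fixed weights."""
    # weights: (n_samples, n_archetypes), X: (n_samples, n_features)
    return np.linalg.pinv(weights) @ X     # (n_archetypes, n_features)
```

When the weight matrix has full column rank and the data are exactly reconstructable, this recovers the generating archetypes.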

archetypax.models.base module

Archetypal Analysis using JAX.

This module provides the foundational implementation of Archetypal Analysis (AA) optimized for GPU acceleration via JAX. It serves as the base class for more advanced implementations in the archetypax package.

Archetypal Analysis identifies extreme patterns (archetypes) in data that can represent the entire dataset through convex combinations, offering both dimensionality reduction and interpretable insights into data structure.

Core Features:

  • JAX-based implementation for GPU/TPU acceleration

  • Scikit-learn compatible API (BaseEstimator, TransformerMixin)

  • Standard k-means++ style initialization

  • Gradient-based optimization with Adam

  • Basic weight and archetype projection methods

This base implementation provides a solid foundation with standard AA features, while more advanced techniques are available in derived classes such as ImprovedArchetypalAnalysis.

Example usage:

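```python
import numpy as np
from archetypax.models import ArchetypalAnalysis

# Placeholder data: 200 samples with 3 features
X = np.random.default_rng(0).normal(size=(200, 3))

# Initialize model
model = ArchetypalAnalysis(
    n_archetypes=5,
    normalize=True,
)

# Fit model and transform data
weights = model.fit_transform(X)
```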

class archetypax.models.base.ArchetypalAnalysis(n_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, normalize: bool = False, **kwargs)

Bases: BaseEstimator, TransformerMixin

GPU-accelerated Archetypal Analysis implementation using JAX.

This class provides the core functionality for identifying archetypes - extreme points that can represent data through convex combinations, offering interpretable and meaningful insights into data structure.

Leverages JAX for efficient GPU computation and automatic differentiation.

__init__(n_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, normalize: bool = False, **kwargs)

Initialize the Archetypal Analysis model.

Parameters:
  • n_archetypes – Number of archetypes to find - determines the dimensionality of the representation space

  • max_iter – Maximum number of iterations for optimization convergence

  • tol – Convergence tolerance for early stopping

  • random_seed – Random seed for reproducible results

  • learning_rate – Learning rate for optimizer - lower values provide better stability at the cost of slower convergence

  • normalize – Whether to normalize the data before fitting - essential for features with different scales

  • **kwargs – Additional keyword arguments, including:

    • early_stopping_patience: number of iterations without improvement before stopping optimization

    • logger_level/verbose_level: control for logging granularity

fit(X: ndarray, normalize: bool = False, **kwargs) → ArchetypalAnalysis

Fit the model to the data.

Identifies optimal archetypes and weights through iterative optimization. Uses the Adam optimizer with projection steps to ensure constraints are satisfied.

Parameters:
  • X – Data matrix (n_samples, n_features)

  • normalize – Whether to normalize the data before fitting

  • **kwargs – Additional keyword arguments for fine-tuning the fitting process

Returns:

Self - fitted model instance
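The fitting loop described above (Adam steps followed by projection onto the constraints) can be sketched in NumPy. This illustration rests on assumptions of ours rather than archetypax's code: the loss is plain mean squared reconstruction error, Adam is written out by hand with the usual default betas, weights are projected onto the simplex by clipping and row renormalization, and archetypes are left unprojected. `fit_sketch` is a hypothetical name.

```python
import numpy as np

def fit_sketch(X, n_archetypes, n_iter=300, lr=0.05, seed=0):
    """Fit weights W and archetypes A by Adam steps plus a simplex projection."""
    rng = np.random.default_rng(seed)
    n_samples = X.shape[0]
    W = rng.dirichlet(np.ones(n_archetypes), size=n_samples)  # rows on simplex
    A = X[rng.choice(n_samples, n_archetypes, replace=False)].astype(float)
    params = [W, A]
    m = [np.zeros_like(p) for p in params]   # Adam first-moment estimates
    v = [np.zeros_like(p) for p in params]   # Adam second-moment estimates
    b1, b2, eps = 0.9, 0.999, 1e-8
    losses = []
    for t in range(1, n_iter + 1):
        R = W @ A - X                                   # residual
        losses.append(float(np.mean(R ** 2)))
        # Gradients of the mean squared residual with respect to W and A
        grads = [2.0 * R @ A.T / R.size, 2.0 * W.T @ R / R.size]
        for p, g, m_i, v_i in zip(params, grads, m, v):
            m_i *= b1
            m_i += (1 - b1) * g
            v_i *= b2
            v_i += (1 - b2) * g ** 2
            m_hat = m_i / (1 - b1 ** t)                 # bias correction
            v_hat = v_i / (1 - b2 ** t)
            p -= lr * m_hat / (np.sqrt(v_hat) + eps)    # in-place Adam step
        # Projection step: clip negatives, then renormalize each row of W
        np.clip(W, 1e-12, None, out=W)
        W /= W.sum(axis=1, keepdims=True)
    return W, A, losses
```

All updates are applied in place, so `W` and `A` always refer to the current parameters, and the recorded `losses` play the role of the history returned by `get_loss_history()`.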

fit_transform(X: ndarray, y: ndarray | None = None, normalize: bool = False) → ndarray

Fit the model and transform the data.

Parameters:
  • X – Data matrix of shape (n_samples, n_features)

  • y – Ignored. Present for API consistency by convention.

  • normalize – Whether to normalize the data before fitting.

Returns:

Weight matrix representing each sample as a combination of archetypes

get_loss_history() → list[float]

Get the loss history from training.

Returns:

List of loss values recorded during fitting

loss_function(archetypes: Array, weights: Array, X: Array) → Array

Calculate reconstruction loss with entropy regularization.

Computes the fundamental objective: minimize reconstruction error while encouraging more discriminative weights through entropy regularization.

Parameters:
  • archetypes – Archetype matrix (n_archetypes, n_features)

  • weights – Weight matrix (n_samples, n_archetypes)

  • X – Data matrix (n_samples, n_features)

Returns:

Combined loss value as a scalar
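As a worked form of this objective, the sketch below assumes (our assumption) mean squared reconstruction error plus a weighted mean Shannon entropy of the weight rows. Lower entropy means weights concentrated on fewer archetypes, which matches the goal of more discriminative weights. The `entropy_weight` parameter is hypothetical; the base class constructor exposes no such argument.

```python
import numpy as np

def entropy_regularized_loss(archetypes, weights, X, entropy_weight=0.01, eps=1e-10):
    """Mean squared reconstruction error plus a mean row-entropy penalty."""
    reconstruction = weights @ archetypes                # (n_samples, n_features)
    recon_error = np.mean((X - reconstruction) ** 2)
    # Shannon entropy of each weight row; eps avoids log(0)
    entropy = -np.sum(weights * np.log(weights + eps), axis=1)
    return recon_error + entropy_weight * np.mean(entropy)
```

With perfect reconstruction, one-hot weights give a loss of essentially zero, while uniform weights over two archetypes add `entropy_weight * log(2)`.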

project_archetypes(archetypes: Array, X: Array) → Array

Project archetypes using soft assignment based on k-nearest neighbors.

Ensures archetypes remain within the convex hull of data points by creating soft assignments based on proximity. This approach offers better stability than hard assignment methods.

Parameters:
  • archetypes – Archetype matrix (n_archetypes, n_features)

  • X – Original data matrix (n_samples, n_features)

Returns:

Projected archetype matrix (n_archetypes, n_features)
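A sketch of soft k-nearest-neighbor projection follows. It assumes each archetype is replaced by a softmax-weighted average of its k nearest samples, with weights decreasing in distance. The values of `k` and `temperature` and the function name are hypothetical. Because the result is a convex combination of data points, it always lies inside the convex hull, which is also why this style of projection tends to pull archetypes inward (as noted for `project_archetypes_knn`).

```python
import numpy as np

def soft_knn_projection(archetypes, X, k=5, temperature=1.0):
    """Replace each archetype by a distance-weighted average of its k nearest samples."""
    # Pairwise squared distances: (n_archetypes, n_samples)
    d2 = np.sum((archetypes[:, None, :] - X[None, :, :]) ** 2, axis=-1)
    nn = np.argsort(d2, axis=1)[:, :k]                  # k nearest sample indices
    nn_d2 = np.take_along_axis(d2, nn, axis=1)          # their squared distances
    # Softmax over negative distances, shifted by the row minimum for stability
    logits = -(nn_d2 - nn_d2.min(axis=1, keepdims=True)) / temperature
    w = np.exp(logits)
    w /= w.sum(axis=1, keepdims=True)
    # Convex combination of the neighbors: (n_archetypes, n_features)
    return np.einsum("ak,akf->af", w, X[nn])
```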

project_weights(weights: Array) → Array

Project weights to satisfy simplex constraints with numerical stability.

Ensures that weights form valid convex combinations (non-negative and sum to 1) while avoiding numerical underflow/overflow issues.

Parameters:

weights – Weight matrix (n_samples, n_archetypes)

Returns:

Projected weight matrix (n_samples, n_archetypes)
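One standard way to enforce the simplex constraint is the sort-based Euclidean projection onto the probability simplex, sketched below for each row of a weight matrix. We do not know whether archetypax uses this exact projection or a simpler scheme such as clipping and renormalizing; the function name is hypothetical.

```python
import numpy as np

def project_rows_to_simplex(weights):
    """Euclidean projection of each row onto {w : w >= 0, sum(w) = 1}."""
    W = np.asarray(weights, dtype=float)
    n, k = W.shape
    U = -np.sort(-W, axis=1)                     # rows sorted in descending order
    css = np.cumsum(U, axis=1) - 1.0             # cumulative sums minus target sum
    ranks = np.arange(1, k + 1)
    support = U - css / ranks > 0                # True on a prefix of each row
    rho = k - 1 - np.argmax(support[:, ::-1], axis=1)  # last True position
    theta = css[np.arange(n), rho] / (rho + 1)   # per-row shift
    return np.maximum(W - theta[:, None], 0.0)
```

Unlike clipping and renormalizing, this returns the point on the simplex closest to each input row in Euclidean distance, and rows already on the simplex are left unchanged.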

reconstruct(X: ndarray | None = None) → ndarray

Reconstruct data using the learned archetypes.

Parameters:

X – Data to reconstruct, or None to use the training data

Returns:

Reconstructed data of shape (n_samples, n_features)

set_fit_request(*, normalize: bool | None | str = '$UNCHANGED$') → ArchetypalAnalysis

Request metadata passed to the fit method.

Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to fit.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version 1.3.

Note

This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.

Parameters:

normalize (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for normalize parameter in fit.

Returns:

self – The updated object.

Return type:

object

transform(X: ndarray, y: ndarray | None = None, **kwargs) → ndarray

Transform new data to archetype weights.

Parameters:
  • X – New data matrix of shape (n_samples, n_features)

  • y – Ignored. Present for API consistency by convention.

  • **kwargs – Additional keyword arguments for the transform method.

Returns:

Weight matrix representing each sample as a combination of archetypes

archetypax.models.biarchetypes module

Biarchetypal Analysis using JAX.

This module extends ImprovedArchetypalAnalysis to perform dual-directional pattern discovery, identifying archetypes in both observation space (rows) and feature space (columns) simultaneously. While traditional archetypal analysis only finds patterns in one direction, biarchetypal analysis provides a richer understanding by decomposing data as: X ≃ alpha·beta·X·theta·gamma
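To make the four-factor decomposition concrete, the shape bookkeeping can be checked with random factors. The shapes follow our reading of the docstrings in this module (alpha: n_samples × n_row_archetypes, gamma: n_col_archetypes × n_features, and Z = β·X·θ: n_row_archetypes × n_col_archetypes). The orientation of beta and theta, and the convention that theta and gamma have columns summing to 1, are inferred rather than documented. The random Dirichlet factors are placeholders, not fitted values.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features = 50, 8
n_row, n_col = 4, 3

X = rng.normal(size=(n_samples, n_features))
# alpha and beta: each row is a convex combination (non-negative, sums to 1)
alpha = rng.dirichlet(np.ones(n_row), size=n_samples)      # (n_samples, n_row)
beta = rng.dirichlet(np.ones(n_samples), size=n_row)       # (n_row, n_samples)
# theta and gamma: each column is a convex combination
theta = rng.dirichlet(np.ones(n_features), size=n_col).T   # (n_features, n_col)
gamma = rng.dirichlet(np.ones(n_col), size=n_features).T   # (n_col, n_features)

Z = beta @ X @ theta              # biarchetypes: (n_row, n_col)
X_hat = alpha @ Z @ gamma         # reconstruction: (n_samples, n_features)
loss = np.mean((X - X_hat) ** 2)  # mean squared reconstruction error
```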

Core Features:

  • Discovers extreme patterns in both observations and features

  • Reveals cross-modal relationships between row and column archetypes

  • Creates a more interpretable representation via biarchetypes

  • Handles complex data with interdependent row and column structures

The four-factor decomposition offers deeper insights than traditional methods by capturing how observation patterns interact with feature patterns throughout the data.

Example usage:

```python
import numpy as np
from archetypax.models import BiarchetypalAnalysis

# Placeholder data: 200 samples with 10 features
X = np.random.default_rng(0).normal(size=(200, 10))

# Initialize model with separate row and column archetype counts
model = BiarchetypalAnalysis(
    n_row_archetypes=4,
    n_col_archetypes=3,
    projection_method="default",
    normalize=True,
)

# Fit model and get dual-directional representations
row_weights, col_weights = model.fit_transform(X)

# Extract bi-archetypes matrix (core patterns)
biarchetypes = model.get_biarchetypes()
```

Based on Alcacer et al., “Biarchetype analysis: simultaneous learning of observations and features based on extremes.”

class archetypax.models.biarchetypes.BiarchetypalAnalysis(n_row_archetypes: int, n_col_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, projection_method: str = 'default', lambda_reg: float = 0.01, **kwargs)

Bases: ImprovedArchetypalAnalysis

Biarchetypal Analysis for dual-directional pattern discovery.

This implementation extends archetypal analysis to simultaneously identify extreme patterns in both observations (rows) and features (columns), offering a richer understanding of data structure. Traditional archetypal analysis only identifies patterns in observation space, missing crucial feature-level insights.

By factorizing the data matrix X as X ≃ alpha·beta·X·theta·gamma, Biarchetypal Analysis (BA) provides several advantages:

  • Captures both observation-level and feature-level patterns

  • Enables cross-modal analysis between observations and features

  • Creates a more compact and interpretable representation via biarchetypes

  • Reveals latent relationships that single-directional methods cannot detect

This implementation is based on the work by Alcacer et al., “Biarchetype analysis: simultaneous learning of observations and features based on extremes.”

__init__(n_row_archetypes: int, n_col_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, projection_method: str = 'default', lambda_reg: float = 0.01, **kwargs)

Initialize the Biarchetypal Analysis model.

Parameters:
  • n_row_archetypes – Number of row archetypes - controls expressiveness in observation space (rows)

  • n_col_archetypes – Number of column archetypes - controls expressiveness in feature space (columns)

  • max_iter – Maximum optimization iterations - higher values enable better convergence at computational cost

  • tol – Convergence tolerance for early stopping - smaller values yield more precise solutions but require more iterations

  • random_seed – Random seed for reproducibility across runs

  • learning_rate – Gradient descent step size - critical balance between convergence speed and stability

  • projection_method – Method for projecting archetypes to extreme points:

    • "default": convex boundary approximation

    • "convex_hull": convex hull approximation

    • "knn": k-nearest neighbors approximation

  • lambda_reg – Regularization strength - controls the sparsity/smoothness trade-off in archetype weights

  • **kwargs – Additional parameters, including:

    • early_stopping_patience: iterations with no improvement before stopping

    • verbose_level/logger_level: controls logging detail

fit(X: ndarray, normalize: bool = False, **kwargs) → BiarchetypalAnalysis

Fit the Biarchetypal Analysis model to identify dual-perspective archetypes.

This core method performs the four-factor decomposition of the data matrix, simultaneously discovering patterns in observation and feature spaces. The implementation employs advanced optimization strategies including:

  1. Sophisticated initialization for both row and column factors

  2. Adaptive learning rate scheduling for stable convergence

  3. Specialized projection operations to maintain meaningful boundaries

  4. Careful numerical handling to prevent instability

  5. Early stopping with convergence monitoring

These optimizations are essential due to the complexity of the four-factor model, which is more challenging to optimize than standard Archetypal Analysis.

Parameters:
  • X – Data matrix (n_samples, n_features)

  • normalize – Whether to normalize features - essential for data with different scales

  • **kwargs – Additional parameters for customizing the fitting process

Returns:

Self - fitted model instance with discovered biarchetypes

fit_transform(X: ndarray, y: ndarray | None = None, normalize: bool = False, **kwargs) → Any

Fit the model and transform the data in a single operation.

This convenience method combines model fitting and data transformation in a single step, offering two key advantages:

  1. Computational efficiency by avoiding redundant calculations

  2. Simplified workflow for immediate biarchetypal representation

The method is particularly valuable when the biarchetypal representation is needed immediately after fitting, such as in analysis pipelines or when integrating with scikit-learn compatible frameworks.

Parameters:
  • X – Data matrix to fit and transform (n_samples, n_features)

  • y – Ignored. Present for scikit-learn API compatibility

  • normalize – Whether to normalize features before fitting

  • **kwargs – Additional parameters passed to fit()

Returns:

Tuple of (row_weights, col_weights) representing the data in biarchetypal space

get_all_archetypes() → tuple[ndarray, ndarray]

Retrieve both row and column archetypes in a single call.

This convenience method provides access to both directions of archetypal analysis simultaneously, facilitating comprehensive analysis and visualization of the dual-perspective patterns. Accessing both archetypes together is particularly valuable for cross-modal analysis examining relationships between observation patterns and feature patterns.

Returns:

Tuple of (row_archetypes, column_archetypes) matrices

get_all_weights() → tuple[ndarray, ndarray]

Retrieve both row and column weights in a single call.

This convenience method provides access to all weight coefficients simultaneously, enabling comprehensive analysis of how observations and features relate to their respective archetypes. This unified view is particularly valuable for understanding the full biarchetypal decomposition and how information flows between the row and column spaces.

Returns:

Tuple of (row_weights, column_weights) matrices

get_biarchetypes() → ndarray

Retrieve the core biarchetypes matrix.

The biarchetypes matrix (Z = β·X·θ) represents the heart of the model, capturing the essential patterns at the intersection of row and column archetypes.

This matrix provides a compact representation of the data’s underlying structure, with each element representing a specific row-column archetype interaction.

Access to this matrix is essential for visualization, interpretation, and advanced analysis of the identified patterns.

Returns:

Biarchetypes matrix (n_row_archetypes, n_col_archetypes)

get_col_archetypes() → ndarray

Retrieve the column archetypes.

Column archetypes represent extreme patterns in feature space, describing distinctive feature combinations or “feature types.”

This perspective is unique to biarchetypal analysis and provides crucial insights about feature relationships that would be missed in standard archetypal analysis.

These archetypes enable feature-level interpretations and can reveal coordinated feature behaviors across different data contexts.

Returns:

Column archetypes matrix (n_col_archetypes, n_features)

get_col_weights() → ndarray

Retrieve the column coefficients (gamma).

Column weights represent how each feature is composed as a mixture of column archetypes. These weights provide unique insights into:

  1. Which feature patterns are expressed in each original feature

  2. How features group together based on shared archetype influence

  3. Feature importance through the lens of archetypal patterns

  4. Potential redundancies in the feature space

This feature-space perspective is a distinguishing advantage of biarchetypal analysis compared to standard archetypal methods.

Returns:

Column weight matrix (n_col_archetypes, n_features)
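As a sketch of point 2, features can be grouped by the column archetype that dominates their weights. The `gamma` values below are made up for illustration; with the real model they would come from `get_col_weights()`.

```python
import numpy as np

# Hypothetical gamma: 3 column archetypes x 6 features; each column sums to 1
gamma = np.array([
    [0.8, 0.7, 0.1, 0.0, 0.2, 0.1],
    [0.1, 0.2, 0.8, 0.9, 0.1, 0.1],
    [0.1, 0.1, 0.1, 0.1, 0.7, 0.8],
])
assert np.allclose(gamma.sum(axis=0), 1.0)

# Dominant column archetype for each feature (argmax down each column)
dominant = gamma.argmax(axis=0)

# Group feature indices by their dominant archetype
groups = {int(k): np.flatnonzero(dominant == k).tolist() for k in range(gamma.shape[0])}
print(groups)  # {0: [0, 1], 1: [2, 3], 2: [4, 5]}
```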

get_row_archetypes() ndarray[source]

Retrieve the row archetypes.

Row archetypes represent extreme patterns in observation space, describing distinctive types of data points. These archetypes are essential for understanding the primary modes of variation among observations and provide the foundation for interpreting data point weights.

In the biarchetypal model, row archetypes are computed from the data matrix via the beta coefficients as β·X, so each row archetype is a weighted combination of observations.

Returns:

Row archetypes matrix (n_row_archetypes, n_features)

get_row_weights() ndarray[source]

Retrieve the row coefficients (alpha).

Row weights represent how each data point is composed as a mixture of row archetypes. These weights are essential for:

  1. Understanding which archetype patterns dominate each observation

  2. Clustering similar observations based on their archetype compositions

  3. Detecting anomalies as points with unusual archetype weights

  4. Creating reduced-dimension visualizations based on archetype space

The weights are constrained to be non-negative and sum to 1 (simplex constraint), making them directly interpretable as proportions.

Returns:

Row weight matrix (n_samples, n_row_archetypes)
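Uses 2 and 4 can be sketched in a few lines of NumPy. The weights below are invented for illustration. For three archetypes, placing them at the corners of a triangle turns each weight vector into a 2-D point (a ternary plot); this layout is a common convention and is not claimed to be a library feature.

```python
import numpy as np

# Hypothetical row weights: 4 observations x 3 row archetypes (rows sum to 1)
alpha = np.array([
    [0.90, 0.05, 0.05],
    [0.10, 0.85, 0.05],
    [0.34, 0.33, 0.33],
    [0.05, 0.05, 0.90],
])

# Clustering: assign each observation to its dominant archetype
clusters = alpha.argmax(axis=1)

# Visualization: archetypes at triangle corners; each observation lands at
# the weight-averaged position of the corners
corners = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, np.sqrt(3) / 2]])
coords_2d = alpha @ corners  # (n_samples, 2)
```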

loss_function(params: dict[str, Array], X: Array) Array[source]

Calculate the composite reconstruction loss for biarchetypal factorization.

This core objective function balances reconstruction quality with sparsity promotion to ensure interpretable representations. Unlike standard AA, the biarchetypal loss operates on a four-factor decomposition of the data, X ≈ α·β·X·θ·γ, which requires careful numerical handling to prevent instability during optimization.

The loss promotes three key properties:

  1. Accurate data reconstruction through the biarchetypal representation

  2. Sparse coefficients for interpretable patterns

  3. Numerical stability through explicit type control

Parameters:
  • params – Dictionary containing the four model matrices:

      ◦ alpha: Row coefficients (n_samples, n_row_archetypes)

      ◦ beta: Row archetypes (n_row_archetypes, n_samples)

      ◦ theta: Column archetypes (n_features, n_col_archetypes)

      ◦ gamma: Column coefficients (n_col_archetypes, n_features)

  • X – Data matrix (n_samples, n_features)

Returns:

Combined loss value incorporating reconstruction and regularization terms
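The exact regularizer is not documented here, so this NumPy sketch assumes a mean-squared reconstruction error plus an entropy penalty on the row and column weights. An L1 norm would not work for this purpose, because simplex-constrained weights always have L1 norm 1. The function name and the `lambda_sparsity` argument are hypothetical.

```python
import numpy as np

def biarchetypal_loss_sketch(params, X, lambda_sparsity=0.01):
    """Hypothetical composite loss; the library's exact terms and weights may differ."""
    # Explicit dtype control, in the spirit of the numerical-stability goal
    X = X.astype(np.float32)
    alpha, beta = params["alpha"].astype(np.float32), params["beta"].astype(np.float32)
    theta, gamma = params["theta"].astype(np.float32), params["gamma"].astype(np.float32)

    # Four-factor reconstruction: X ≈ alpha · beta · X · theta · gamma
    X_hat = alpha @ beta @ X @ theta @ gamma
    reconstruction = np.mean((X - X_hat) ** 2)

    # Entropy is low when a weight vector concentrates on few archetypes
    eps = 1e-8
    ent_alpha = -np.mean(np.sum(alpha * np.log(alpha + eps), axis=1))  # per sample
    ent_gamma = -np.mean(np.sum(gamma * np.log(gamma + eps), axis=0))  # per feature
    return reconstruction + lambda_sparsity * (ent_alpha + ent_gamma)

rng = np.random.default_rng(0)
n, m, kr, kc = 20, 5, 3, 2
X = rng.normal(size=(n, m))
params = {
    "alpha": rng.dirichlet(np.ones(kr), size=n),    # (n, kr), rows sum to 1
    "beta": rng.dirichlet(np.ones(n), size=kr),     # (kr, n), rows sum to 1
    "theta": rng.dirichlet(np.ones(m), size=kc).T,  # (m, kc), columns sum to 1
    "gamma": rng.dirichlet(np.ones(kc), size=m).T,  # (kc, m), columns sum to 1
}
loss = biarchetypal_loss_sketch(params, X)
```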

project_col_archetypes(archetypes: Array, X: Array) Array[source]

Project column archetypes to the boundary of the feature space.

This critical counterpart to row archetype projection ensures column archetypes represent distinct feature patterns. While conceptually similar to row projection, this operation works in the transposed space, treating features as observations and finding extremes among them.

Without this specialized projection, the feature archetypes would not capture meaningful feature combinations, undermining the dual-perspective advantage of biarchetypal analysis.

The implementation:

  1. Transposes the problem to work in feature space

  2. Identifies feature combinations that represent extremes

  3. Creates boundary points through weighted feature combinations

  4. Maintains numerical stability throughout

Parameters:
  • archetypes – Column archetype matrix (n_features, n_col_archetypes)

  • X – Data matrix (n_samples, n_features)

Returns:

Projected column archetypes positioned at the boundaries of feature space, representing distinct feature patterns

project_col_coefficients(coefficients: Array) Array[source]

Project column coefficients to satisfy simplex constraints.

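This projection enforces valid convex combinations in the feature space. It differs from the row coefficient projection in the axis it normalizes over. Because the coefficient matrix has shape (n_col_archetypes, n_features), each column (one per feature) must be non-negative and sum to 1, so the normalization runs along axis 0 rather than along the rows. This ensures each feature is represented as a convex combination of column archetypes.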

This axis-specific projection is a key distinction between standard AA and biarchetypal analysis, enabling the dual-directional nature of the model.

Parameters:

coefficients – Column coefficient matrix (n_col_archetypes, n_features)

Returns:

Projected coefficients with each feature’s weights summing to 1, maintaining valid convex combinations in feature space
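A minimal clip-and-renormalize sketch of the column-wise constraint, assuming that scheme (the library may project differently). The function name is hypothetical. Columns that are entirely non-positive fall back to uniform weights, which avoids dividing by zero.

```python
import numpy as np

def project_col_coefficients_sketch(coefficients, eps=1e-10):
    """Hypothetical: clip negatives, then make each column (feature) sum to 1."""
    coefficients = np.asarray(coefficients, dtype=float)
    n_archetypes = coefficients.shape[0]
    clipped = np.maximum(coefficients, 0.0)
    col_sums = clipped.sum(axis=0, keepdims=True)     # (1, n_features): axis 0, not 1
    uniform = np.full_like(clipped, 1.0 / n_archetypes)
    safe_sums = np.where(col_sums > eps, col_sums, 1.0)
    # Columns with no positive mass fall back to uniform weights
    return np.where(col_sums > eps, clipped / safe_sums, uniform)

gamma_raw = np.array([[0.5, -0.2, -1.0],
                      [1.5,  0.6, -2.0]])
gamma = project_col_coefficients_sketch(gamma_raw)
# column 0 -> [0.25, 0.75], column 1 -> [0.0, 1.0], column 2 -> [0.5, 0.5]
```

The row-wise version would be the same code with `axis=1`, which is the axis distinction described above.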

project_row_archetypes(archetypes: Array, X: Array) Array[source]

Project row archetypes to the convex hull boundary of data points.

This critical operation ensures row archetypes remain at meaningful extremes of the observation space, where they represent distinct, interpretable patterns. Without this projection, archetypes would tend to collapse toward the data centroid during optimization, losing their representative power.

The implementation uses an adaptive multi-point boundary approximation that:

  1. Identifies extreme directions from the data centroid

  2. Selects multiple boundary points along each direction

  3. Creates weighted combinations that maximize distinctiveness

  4. Maintains numeric stability throughout the process

Parameters:
  • archetypes – Row archetype matrix (n_row_archetypes, n_samples)

  • X – Data matrix (n_samples, n_features)

Returns:

Projected row archetypes positioned at meaningful boundaries of the data’s convex hull
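The four steps can be sketched in NumPy under stated assumptions. The sketch works on the (n_row_archetypes, n_samples) coefficient matrix documented above. The names `k`, `blend`, and `temperature` are hypothetical, and `blend` only plays a role similar to a projection strength. Because each new row is a convex mix of two simplex rows, it stays on the simplex.

```python
import numpy as np

def project_row_archetypes_sketch(beta, X, k=5, blend=0.1, temperature=1.0):
    """Hypothetical multi-point boundary projection operating on beta.

    beta: (n_row_archetypes, n_samples), rows on the simplex
    X:    (n_samples, n_features)
    """
    centroid = X.mean(axis=0)
    positions = beta @ X                       # current archetypes in feature space
    new_beta = np.empty_like(beta)
    for i, pos in enumerate(positions):
        # 1. Extreme direction: from the centroid toward the current archetype
        direction = pos - centroid
        direction /= np.linalg.norm(direction) + 1e-10
        # 2. The k data points reaching furthest along that direction
        scores = (X - centroid) @ direction
        top = np.argsort(scores)[-k:]
        # 3. Softmax-weight those boundary points (furthest gets most weight)
        w = np.exp((scores[top] - scores[top].max()) / temperature)
        boundary = np.zeros(X.shape[0])
        boundary[top] = w / w.sum()
        # 4. Convex blend toward the boundary keeps the row on the simplex
        new_beta[i] = (1 - blend) * beta[i] + blend * boundary
    return new_beta

rng = np.random.default_rng(0)
X = rng.normal(size=(40, 3))
beta = rng.dirichlet(np.ones(40), size=4)
beta_proj = project_row_archetypes_sketch(beta, X)
```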

project_row_coefficients(coefficients: Array) Array[source]

Project row coefficients to satisfy simplex constraints.

This projection is essential for maintaining valid convex combinations in the observation space. The simplex constraint (non-negative weights summing to 1) ensures that each data point is represented as a proper weighted combination of row archetypes. Without this constraint, the model would lose its interpretability and might generate unrealistic representations.

The implementation includes numerical safeguards to prevent division by zero and ensure stable optimization even with extreme weight values.

Parameters:

coefficients – Row coefficient matrix (n_samples, n_row_archetypes)

Returns:

Projected coefficients satisfying simplex constraints (non-negative, sum to 1)
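The text above describes normalization safeguards. For comparison, the sketch below shows a different, well-known technique: the exact Euclidean projection of each row onto the simplex, computed by sorting. It also yields non-negative rows that sum to 1. It is not claimed to be the library's method, and the function name is hypothetical.

```python
import numpy as np

def project_simplex_rows(V):
    """Euclidean projection of each row of V onto the probability simplex."""
    V = np.asarray(V, dtype=float)
    n, k = V.shape
    U = -np.sort(-V, axis=1)                   # each row sorted in descending order
    css = np.cumsum(U, axis=1) - 1.0           # cumulative sums minus the target sum
    idx = np.arange(1, k + 1)
    cond = U - css / idx > 0                   # True on a prefix of each row
    rho = k - np.argmax(cond[:, ::-1], axis=1) # length of that prefix
    theta = css[np.arange(n), rho - 1] / rho   # per-row shift
    return np.maximum(V - theta[:, None], 0.0)

W = project_simplex_rows(np.array([[2.0, 0.0, 0.0], [0.5, 0.5, 0.5]]))
# -> [[1, 0, 0], [1/3, 1/3, 1/3]]
```

Unlike clip-and-renormalize, this projection leaves rows that already lie on the simplex unchanged and returns the closest valid point otherwise.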

reconstruct(X: ndarray | None = None) ndarray[source]

Reconstruct data from biarchetypal representation.

This method provides the inverse operation of transform(), reconstructing data points from their biarchetypal weights. This capability serves several critical purposes:

  1. Validation of model quality through reconstruction error assessment

  2. Interpretation of what specific archetypes represent in data space

  3. Generation of synthetic data by manipulating archetype weights

  4. Noise reduction by reconstructing data through dominant archetypes

The method handles both original training data and new data points.

Parameters:

X – Optional data matrix to reconstruct. If None, uses the training data

Returns:

Reconstructed data matrix in the original feature space
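Since X ≈ α·β·X·θ·γ and Z = β·X·θ, a reconstruction can be written as α·Z·γ. The NumPy sketch below uses random factors as stand-ins for fitted ones. For new data, it assumes the row weights would come from transform(). The relative Frobenius error is one common quality check and is not necessarily the metric the library reports.

```python
import numpy as np

rng = np.random.default_rng(1)
n, m, kr, kc = 30, 6, 3, 2
X = rng.normal(size=(n, m))
alpha = rng.dirichlet(np.ones(kr), size=n)    # (n, kr) row weights
beta = rng.dirichlet(np.ones(n), size=kr)     # (kr, n)
theta = rng.dirichlet(np.ones(m), size=kc).T  # (m, kc)
gamma = rng.dirichlet(np.ones(kc), size=m).T  # (kc, m) column weights

Z = beta @ X @ theta                          # biarchetypes (kr, kc)
X_rec = alpha @ Z @ gamma                     # reconstruction (n, m)

# Relative Frobenius error as a quick model-quality check
rel_error = np.linalg.norm(X - X_rec) / np.linalg.norm(X)
```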

set_fit_request(*, normalize: bool | None | str = '$UNCHANGED$') BiarchetypalAnalysis

Request metadata passed to the fit method.

Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to fit.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version 1.3.

Note

This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.

Parameters:

normalize (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for normalize parameter in fit.

Returns:

self – The updated object.

Return type:

object

transform(X: ndarray, y: ndarray | None = None, **kwargs) Any[source]

Transform new data into dual-directional archetype space.

This method computes optimal weights to represent new data in terms of discovered archetypes, enabling consistent interpretation of new observations within the established biarchetypal framework.

Unlike conventional AA, this transform operates in both row and column spaces, providing a holistic representation of new data that preserves the model’s dual-perspective advantage. The implementation efficiently leverages pre-trained biarchetypes to avoid redundant computation.

Parameters:
  • X – New data matrix (n_samples, n_features) to transform

  • y – Ignored. Present for scikit-learn API compatibility

  • **kwargs – Additional parameters for customizing transformation

Returns:

Tuple of (row_weights, col_weights) representing the data in the biarchetypal space
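The row side of such a transform can be sketched as fitting simplex weights for new observations while the archetypes stay fixed. The sketch uses exponentiated-gradient updates, a swapped-in technique that keeps each row non-negative and summing to 1 without a separate projection. The function and parameter names are hypothetical, and the column side is not shown.

```python
import numpy as np

def transform_rows_sketch(X_new, archetypes, n_iter=1000, lr=0.5):
    """Hypothetical: fit simplex weights W so that W @ archetypes ≈ X_new."""
    n, k = X_new.shape[0], archetypes.shape[0]
    W = np.full((n, k), 1.0 / k)                   # start at the simplex center
    scale = np.linalg.norm(archetypes, 2) ** 2     # step normalization
    for _ in range(n_iter):
        grad = (W @ archetypes - X_new) @ archetypes.T  # gradient of 0.5*||X - WA||^2
        W = W * np.exp(-lr * grad / scale)         # multiplicative update stays positive
        W /= W.sum(axis=1, keepdims=True)          # rows sum to 1
    return W

# Identity archetypes make the true weights easy to check
archetypes = np.eye(3)
true_W = np.array([[0.6, 0.3, 0.1], [0.2, 0.2, 0.6]])
X_new = true_W @ archetypes
W = transform_rows_sketch(X_new, archetypes)
```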

archetypax.models.sparse_archetypes module

Sparse Archetypal Analysis using JAX.

This module extends ImprovedArchetypalAnalysis with sparsity-promoting regularization, enabling more interpretable and focused archetype discovery. By encouraging archetypes to utilize only essential features, this approach addresses a key limitation of standard archetypal analysis: the tendency to produce dense, difficult-to-interpret archetypes in high-dimensional spaces.

Core Features:

  • Creates more interpretable archetypes focusing on truly relevant features

  • Performs automatic feature selection within the archetypal framework

  • Improves robustness to noise and irrelevant dimensions

  • Prevents overfitting to spurious correlations

  • Generates computationally efficient sparse representations

Multiple sparsity techniques are supported:

  • “l1”: L1 regularization (fastest, robust, tends to zero out features)

  • “l0_approx”: Approximated L0 regularization (more aggressive sparsity)

  • “feature_selection”: Entropy-based selection (focuses on key features)
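The three options can be illustrated with common textbook penalties. The exact formulas the library uses are not given here, so each formula below, the function name, and `sigma` are assumptions. The usage compares a dense and a sparse archetype with the same total mass. L1 scores them equally, while the other two penalties favor the sparse one.

```python
import numpy as np

def sparsity_penalty_sketch(archetypes, method="l1", sigma=0.1, eps=1e-10):
    """Hypothetical penalties for the three documented sparsity_method values."""
    A = np.abs(archetypes)
    if method == "l1":
        # Sum of absolute values: shrinks small entries toward zero
        return A.sum()
    if method == "l0_approx":
        # Smooth count of non-zeros: each entry contributes ~1 once |a| >> sigma
        return np.sum(1.0 - np.exp(-(A ** 2) / sigma ** 2))
    if method == "feature_selection":
        # Entropy of each archetype's normalized feature profile
        P = A / (A.sum(axis=1, keepdims=True) + eps)
        return -np.sum(P * np.log(P + eps))
    raise ValueError(f"unknown sparsity method: {method}")

dense = np.array([[1.0, 1.0, 1.0, 1.0]])
sparse = np.array([[4.0, 0.0, 0.0, 0.0]])
```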

Example usage:

```python
import numpy as np
from archetypax.models import SparseArchetypalAnalysis

# Example data: 200 samples, 10 features
X = np.random.default_rng(0).random((200, 10))

# Initialize model with sparsity parameters
model = SparseArchetypalAnalysis(
    n_archetypes=5,
    lambda_sparsity=0.1,
    sparsity_method="l1",
    normalize=True,
)

# Fit model and transform data
weights = model.fit_transform(X)

# Evaluate archetype sparsity (one score per archetype)
sparsity_scores = model.get_archetype_sparsity()
```

class archetypax.models.sparse_archetypes.SparseArchetypalAnalysis(n_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, lambda_reg: float = 0.01, lambda_sparsity: float = 0.1, sparsity_method: str = 'l1', normalize: bool = False, projection_method: str = 'cbap', projection_alpha: float = 0.1, archetype_init_method: str = 'directional', min_volume_factor: float = 0.001, **kwargs)[source]

Bases: ImprovedArchetypalAnalysis

Archetypal Analysis with sparsity constraints for enhanced interpretability.

This implementation addresses a fundamental challenge in standard archetypal analysis: dense archetypes that utilize many features are often difficult to interpret, particularly in high-dimensional datasets where most features may be irrelevant to specific patterns.

By incorporating sparsity constraints, this approach offers several key advantages:

  1. More interpretable archetypes that focus on truly relevant features

  2. Automatic feature selection within the archetypal framework

  3. Improved robustness to noise and irrelevant dimensions

  4. Better generalization by preventing overfitting to spurious correlations

  5. Computationally efficient representations, especially for high-dimensional data

Multiple sparsity-promoting methods are supported, enabling adaptation to different data characteristics and interpretability requirements.

__init__(n_archetypes: int, max_iter: int = 500, tol: float = 1e-06, random_seed: int = 42, learning_rate: float = 0.001, lambda_reg: float = 0.01, lambda_sparsity: float = 0.1, sparsity_method: str = 'l1', normalize: bool = False, projection_method: str = 'cbap', projection_alpha: float = 0.1, archetype_init_method: str = 'directional', min_volume_factor: float = 0.001, **kwargs)[source]

Initialize the Sparse Archetypal Analysis model.

Parameters:
  • n_archetypes – Number of archetypes to discover

  • max_iter – Maximum optimization iterations

  • tol – Convergence tolerance for early stopping

  • random_seed – Random seed for reproducibility

  • learning_rate – Gradient descent step size

  • lambda_reg – Weight regularization strength

  • lambda_sparsity – Archetype sparsity strength

  • sparsity_method – Technique for promoting archetype sparsity (“l1”, “l0_approx”, “feature_selection”)

  • normalize – Whether to normalize features

  • projection_method – Method for projecting archetypes toward the boundary of the data’s convex hull

  • projection_alpha – Strength of boundary projection

  • archetype_init_method – Initialization strategy for archetypes

  • min_volume_factor – Minimum volume requirement for archetype simplex

  • **kwargs – Additional parameters

diversify_archetypes(archetypes: Array, X: Array) Array[source]

Prevent degenerate solutions by ensuring sufficient archetype diversity.

Addresses a challenge where multiple archetypes can collapse to similar positions, especially when sparsity is enforced, reducing model expressiveness.

This method:

  1. Detects potential degeneracy through simplex volume measurement

  2. Systematically pushes archetypes away from each other when needed

  3. Ensures archetypes remain valid (within the convex hull)

  4. Verifies improvement through before/after volume comparison

This step runs outside the JAX-compiled update steps because it relies on non-differentiable operations.

Parameters:
  • archetypes – Current archetypes matrix to diversify

  • X – Data matrix defining the convex hull boundary

Returns:

Diversified archetypes with improved distribution and volume
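Step 1 (degeneracy detection) relies on a simplex volume. A standard way to compute it uses the Gram determinant of the edge vectors, which still works when there are fewer archetypes than features. This is a sketch of that formula and is not necessarily how the library computes it. The function name is hypothetical.

```python
import math
import numpy as np

def simplex_volume(archetypes):
    """Volume of the simplex spanned by k archetypes (rows) in any dimension."""
    V = archetypes[1:] - archetypes[0]       # (k-1, n_features) edge vectors
    gram = V @ V.T
    det = max(np.linalg.det(gram), 0.0)      # clamp tiny negative round-off
    return math.sqrt(det) / math.factorial(V.shape[0])

# A triangle embedded in 3-D has area 0.5; a collapsed set has volume ~0
triangle = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
collapsed = np.array([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0], [0.0, 0.0, 1.0]])
```

A volume near zero signals that two or more archetypes have collapsed onto each other.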

fit(X: ndarray, normalize: bool = False, **kwargs) SparseArchetypalAnalysis[source]

Fit the model to discover sparse, interpretable archetypes.

Orchestrates the complete sparse archetypal analysis process:

  1. Leverages the parent class for core optimization

  2. Applies the selected sparsity-promoting method during optimization

  3. Performs post-processing to ensure archetype diversity

  4. Validates sparsity and geometric properties of the solution

Parameters:
  • X – Data matrix (n_samples, n_features)

  • normalize – Whether to normalize features before fitting

  • **kwargs – Additional parameters for customizing the fitting process

Returns:

Self - fitted model instance with discovered sparse archetypes

fit_transform(X: ndarray, y: ndarray | None = None, normalize: bool = False, **kwargs) ndarray[source]

Fit the model and transform the input data in one operation.

Combines model fitting and data transformation for:

  1. Computational efficiency

  2. Simplified workflow integration with scikit-learn compatible frameworks

Parameters:
  • X – Data matrix to fit and transform (n_samples, n_features)

  • y – Ignored. Present for scikit-learn API compatibility

  • normalize – Whether to normalize features before fitting

  • **kwargs – Additional parameters passed to fit()

Returns:

Weight matrix representing samples as combinations of sparse archetypes

get_archetype_sparsity() ndarray[source]

Calculate the effective sparsity of each archetype.

Uses the Gini coefficient rather than simply counting zeros, providing a more nuanced sparsity metric that works with both hard and soft thresholding.

The Gini coefficient measures inequality among values, with higher values indicating greater sparsity (few large values, many small values).

Returns:

Array of sparsity scores (higher values = more focused archetypes)
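A sketch of the standard Gini coefficient applied to the absolute values of each archetype row. The library's exact variant may differ, and the function name is hypothetical. A perfectly even row scores 0. A row with a single non-zero entry among n features scores (n - 1) / n.

```python
import numpy as np

def gini_sparsity(archetypes, eps=1e-12):
    """Gini coefficient of |values| for each archetype (row)."""
    A = np.sort(np.abs(archetypes), axis=1)     # ascending within each row
    n = A.shape[1]
    idx = np.arange(1, n + 1)
    totals = A.sum(axis=1) + eps                # eps guards all-zero rows
    return (2.0 * (A * idx).sum(axis=1)) / (n * totals) - (n + 1.0) / n

scores = gini_sparsity(np.array([[1.0, 1.0, 1.0, 1.0],
                                 [0.0, 0.0, 0.0, 5.0]]))
# -> approximately [0.0, 0.75]
```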

loss_function(archetypes: Array, weights: Array, X: Array) Array[source]

Calculate the composite loss function incorporating sparsity constraints.

The balance between multiple terms is critical:

  • Reconstruction accuracy: Ensuring archetypes accurately represent the data

  • Archetype sparsity: Promoting focused archetypes that use fewer features

  • Weight interpretability: Encouraging sparse, distinctive weight patterns

  • Boundary alignment: Maintaining archetypes at meaningful data extremes

  • Archetype diversity: Preventing redundant or overlapping archetypes

Parameters:
  • archetypes – Archetype matrix (n_archetypes, n_features)

  • weights – Weight matrix (n_samples, n_archetypes)

  • X – Data matrix (n_samples, n_features)

Returns:

Combined loss incorporating reconstruction error and multiple regularization terms

set_fit_request(*, normalize: bool | None | str = '$UNCHANGED$') SparseArchetypalAnalysis

Request metadata passed to the fit method.

Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config()). Please see User Guide on how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to fit.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

Added in version 1.3.

Note

This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. used inside a Pipeline. Otherwise it has no effect.

Parameters:

normalize (str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED) – Metadata routing for normalize parameter in fit.

Returns:

self – The updated object.

Return type:

object

update_archetypes(archetypes: Array, weights: Array, X: Array) Array[source]

Update archetypes with sparsity promotion and degeneracy prevention.

This enhanced update method extends the standard approach with critical improvements for sparse archetypal analysis:

  1. Sparsity promotion through the selected method (“l1”, “l0_approx”, or “feature_selection”)

  2. Variance-aware noise injection to prevent dimensional collapse

  3. Constraint enforcement to maintain valid convex combinations

  4. Boundary projection to ensure archetypes remain at meaningful extremes

Parameters:
  • archetypes – Current archetype matrix (n_archetypes, n_features)

  • weights – Weight matrix (n_samples, n_archetypes)

  • X – Data matrix (n_samples, n_features)

Returns:

Updated archetypes incorporating sparsity and diversity constraints
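Steps 1 and 2 can be sketched with two small helpers. Soft-thresholding is the standard proximal operator of the L1 norm. The noise helper scales its jitter by each feature's standard deviation in X, which is one plausible reading of "variance-aware" and is an assumption. Both function names are hypothetical. Steps 3 and 4 (constraint enforcement and boundary projection) would follow and are not shown.

```python
import numpy as np

def soft_threshold(archetypes, threshold):
    """Proximal operator of the L1 norm: shrink every entry toward zero."""
    return np.sign(archetypes) * np.maximum(np.abs(archetypes) - threshold, 0.0)

def variance_aware_noise(archetypes, X, scale=0.01, rng=None):
    """Hypothetical jitter proportional to each feature's spread in X."""
    rng = np.random.default_rng() if rng is None else rng
    feature_std = X.std(axis=0)                   # (n_features,)
    return archetypes + scale * feature_std * rng.standard_normal(archetypes.shape)

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 4))
archetypes = rng.normal(size=(3, 4))
updated = variance_aware_noise(soft_threshold(archetypes, 0.05), X, rng=rng)
```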

Module contents

Core model implementations for Archetypal Analysis.

This module provides specialized implementations of Archetypal Analysis algorithms, each addressing specific analytical challenges and use cases. Archetypal Analysis discovers extreme patterns in data that serve as the building blocks for representing all observations as convex combinations.

Available Models:
ArchetypalAnalysis:

Foundational implementation suitable for low-dimensional datasets or initial exploration when computational efficiency matters

ImprovedArchetypalAnalysis:

Enhanced version with advanced initialization strategies, robust optimization, and boundary projection techniques; recommended for most applications because of its more stable convergence

SparseArchetypalAnalysis:

Implementation enforcing feature sparsity in archetypes; well suited to high-dimensional data where interpretability is a priority and feature selection is desirable

BiarchetypalAnalysis:

Dual-direction analysis revealing patterns in both observation and feature spaces simultaneously; ideal for datasets where understanding relationships between features is as important as clustering observations

Basic Usage:

from archetypax.models import ArchetypalAnalysis

model = ArchetypalAnalysis(n_archetypes=5)
model.fit(data)
archetypes = model.get_archetypes()