Usage

Basic Usage

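archetypax provides a scikit-learn-compatible API (fit, transform, inverse_transform) for archetypal analysis, which represents each sample as a convex combination of a small number of extreme points called archetypes. Here's a basic example: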

import numpy as np
from archetypax import ArchetypalAnalysis

# Generate synthetic data
np.random.seed(42)
X = np.random.rand(100, 5)  # 100 samples, 5 features

# Initialize and fit the model
model = ArchetypalAnalysis(n_archetypes=3, random_state=42)
model.fit(X)

# Get archetypes
archetypes = model.archetypes
print("Archetypes shape:", archetypes.shape)

# Transform new data to get weights
X_new = np.random.rand(10, 5)  # 10 new samples
weights = model.transform(X_new)
print("Weights shape:", weights.shape)

# Reconstruct data from weights and archetypes
X_reconstructed = model.inverse_transform(weights)
print("Reconstructed data shape:", X_reconstructed.shape)
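In standard archetypal analysis, each weight row is non-negative and sums to 1, and a reconstruction is the weighted average of the archetypes. The NumPy sketch below uses made-up archetypes and weights (not a fitted archetypax model) to show the shapes involved and why every reconstruction stays inside the range spanned by the archetypes.

```python
import numpy as np

# Made-up stand-ins for fitted quantities: 3 archetypes in a 5-feature space.
rng = np.random.default_rng(0)
archetypes = rng.random((3, 5))                  # shape (n_archetypes, n_features)

# Each row of weights lies on the probability simplex: non-negative, sums to 1.
raw = rng.random((10, 3))
weights = raw / raw.sum(axis=1, keepdims=True)   # shape (n_samples, n_archetypes)

# Reconstruction: each sample is a convex combination of the archetypes.
X_reconstructed = weights @ archetypes           # shape (n_samples, n_features)
print(X_reconstructed.shape)  # (10, 5)
```

Because the weights form convex combinations, each reconstructed feature value lies between the smallest and largest value of that feature among the archetypes.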

Advanced Usage

Customizing Optimization Parameters

The optimization process can be tuned through constructor arguments such as the iteration budget, convergence tolerance, step size, and batch size:

from archetypax import ArchetypalAnalysis

model = ArchetypalAnalysis(
    n_archetypes=5,
    max_iter=1000,
    tol=1e-6,
    learning_rate=0.01,
    batch_size=32,
    random_state=42
)
model.fit(X)
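To illustrate the roles that parameters like max_iter, tol, and learning_rate typically play, here is a generic projected-gradient sketch in NumPy. It is not archetypax's optimizer: it holds the archetypes fixed and only updates the weights, whereas a full archetypal analysis fit also updates the archetypes. The helpers project_to_simplex and fit_weights are hypothetical names used for illustration.

```python
import numpy as np


def project_to_simplex(V):
    """Euclidean projection of each row of V onto the probability simplex
    (sort-based algorithm): rows become non-negative and sum to 1."""
    n, k = V.shape
    U = np.sort(V, axis=1)[:, ::-1]                 # each row sorted in descending order
    css = np.cumsum(U, axis=1) - 1.0                # cumulative sums minus the target total
    cond = U - css / np.arange(1, k + 1) > 0        # candidate support positions
    rho = k - 1 - np.argmax(cond[:, ::-1], axis=1)  # last position where cond holds
    theta = css[np.arange(n), rho] / (rho + 1)      # per-row shift
    return np.maximum(V - theta[:, None], 0.0)


def fit_weights(X, Z, learning_rate=0.01, max_iter=1000, tol=1e-6):
    """Projected gradient descent on weights W for fixed archetypes Z, minimising
    0.5 * ||W @ Z - X||_F^2 with every row of W kept on the simplex."""
    n, k = X.shape[0], Z.shape[0]
    W = np.full((n, k), 1.0 / k)          # start every sample at the simplex centre
    prev_loss = np.inf
    for _ in range(max_iter):             # max_iter caps the number of updates
        R = W @ Z - X                     # residual
        loss = 0.5 * np.sum(R ** 2)
        if prev_loss - loss < tol:        # tol: stop once the loss barely improves
            break
        prev_loss = loss
        grad = R @ Z.T                    # gradient of the loss with respect to W
        W = project_to_simplex(W - learning_rate * grad)  # learning_rate scales each step
    return W


rng = np.random.default_rng(42)
Z = rng.random((3, 5))     # archetypes held fixed for this illustration
X = rng.random((20, 5))
W = fit_weights(X, Z, learning_rate=0.05, max_iter=500, tol=1e-8)
print(W.shape)  # (20, 3)
```

A smaller learning_rate makes each step safer but slower, a tighter tol lets the loop run longer before stopping, and max_iter bounds the total work either way.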

Using the Improved Model

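archetypax also offers ImprovedArchetypalAnalysis, which adds options such as initializing archetypes near the convex hull of the data and a regularization term for more stable optimization: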

from archetypax import ImprovedArchetypalAnalysis

model = ImprovedArchetypalAnalysis(
    n_archetypes=4,
    convex_hull_init=True,  # Initialize archetypes near the convex hull
    regularization=0.001,   # Add regularization for stability
    random_state=42
)
model.fit(X)
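The idea behind convex-hull-style initialization is to start from data points on the outer boundary rather than from random interior points. The sketch below uses a generic furthest-point heuristic (a hypothetical helper, not archetypax's actual initializer): it starts from the point furthest from the centroid, which is always a vertex of the convex hull, then repeatedly adds the point furthest from all points chosen so far, which tends to favour boundary points.

```python
import numpy as np


def furthest_point_init(X, n_archetypes):
    """Pick n_archetypes rows of X that lie far from the centroid and from each other."""
    centroid = X.mean(axis=0)
    chosen = [int(np.argmax(np.linalg.norm(X - centroid, axis=1)))]
    # Distance from every point to its nearest already-chosen point.
    min_dist = np.linalg.norm(X - X[chosen[0]], axis=1)
    while len(chosen) < n_archetypes:
        nxt = int(np.argmax(min_dist))    # point furthest from the current selection
        chosen.append(nxt)
        min_dist = np.minimum(min_dist, np.linalg.norm(X - X[nxt], axis=1))
    return X[chosen], chosen


# Corners of a unit square plus interior points: the corners are the hull vertices.
rng = np.random.default_rng(0)
corners = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
interior = 0.25 + 0.5 * rng.random((50, 2))
X = np.vstack([corners, interior])
init, chosen = furthest_point_init(X, 4)
print(sorted(chosen))  # [0, 1, 2, 3]
```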

Using Biarchetypal Analysis

For more expressive representations, archetypax provides biarchetypal analysis, which learns two sets of archetypes instead of one:

import numpy as np
from archetypax import BiarchetypalAnalysis

# Generate synthetic data
np.random.seed(42)
X = np.random.rand(100, 5)  # 100 samples, 5 features

# Initialize and fit the model with two sets of archetypes
model = BiarchetypalAnalysis(
    n_archetypes_first=2,   # Number of archetypes in the first set
    n_archetypes_second=2,  # Number of archetypes in the second set
    mixture_weight=0.5,     # Weight for mixing the two archetype sets (0-1)
    max_iter=500,
    random_state=42
)
model.fit(X)

# Get both sets of archetypes
first_archetypes, second_archetypes = model.get_all_archetypes()
print("First archetype set shape:", first_archetypes.shape)
print("Second archetype set shape:", second_archetypes.shape)

# Get both sets of weights
first_weights, second_weights = model.get_all_weights()
print("First weight set shape:", first_weights.shape)
print("Second weight set shape:", second_weights.shape)

# Reconstruct data using both sets of archetypes
X_reconstructed = model.reconstruct()
print("Reconstructed data shape:", X_reconstructed.shape)
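This section does not specify how mixture_weight combines the two archetype sets. Purely as a hypothetical illustration, the NumPy sketch below assumes a linear blend of the two reconstructions, mixture_weight * (W1 @ Z1) + (1 - mixture_weight) * (W2 @ Z2); blended_reconstruction is a made-up helper, and BiarchetypalAnalysis's own docstring is the authority on the actual model.

```python
import numpy as np


def blended_reconstruction(W1, Z1, W2, Z2, mixture_weight=0.5):
    """Hypothetical: linearly blend two archetype reconstructions with a weight in [0, 1]."""
    if not 0.0 <= mixture_weight <= 1.0:
        raise ValueError("mixture_weight must lie in [0, 1]")
    return mixture_weight * (W1 @ Z1) + (1.0 - mixture_weight) * (W2 @ Z2)


rng = np.random.default_rng(42)


def random_simplex_rows(n, k):
    """Random weight matrix whose rows are non-negative and sum to 1."""
    R = rng.random((n, k))
    return R / R.sum(axis=1, keepdims=True)


# Two archetype sets of 2 archetypes each, in a 5-feature space.
W1, W2 = random_simplex_rows(100, 2), random_simplex_rows(100, 2)
Z1, Z2 = rng.random((2, 5)), rng.random((2, 5))
X_hat = blended_reconstruction(W1, Z1, W2, Z2, mixture_weight=0.5)
print(X_hat.shape)  # (100, 5)
```

Under this assumption, a weight of 1 or 0 reduces to ordinary archetypal analysis with only the first or second set.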

Visualization

archetypax includes visualization utilities for exploring archetypal analysis results:

import matplotlib.pyplot as plt
from archetypax import ArchetypalAnalysis
from archetypax.tools.visualization import ArchetypalAnalysisVisualizer

# Fit model
model = ArchetypalAnalysis(n_archetypes=3)
model.fit(X)

# Plot archetypes in 2D projection
ArchetypalAnalysisVisualizer.plot_archetypes_2d(model, X)
plt.show()

# Plot membership weights
weights = model.transform(X)
ArchetypalAnalysisVisualizer.plot_membership_weights(weights)
plt.show()

# Plot archetype profiles
ArchetypalAnalysisVisualizer.plot_archetype_profiles(model, feature_names=['F1', 'F2', 'F3', 'F4', 'F5'])
plt.show()
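A 2D plot of 5-feature data needs a projection onto a plane. Which projection plot_archetypes_2d uses is not stated here; a common choice is principal component analysis (PCA), sketched below in NumPy with a hypothetical helper. The archetypes are projected with the same mean and components as the data, so they appear in the same coordinate system. Because the projection is linear, reconstructions stay convex combinations of the projected archetypes.

```python
import numpy as np


def pca_project_2d(X, archetypes):
    """Project data and archetypes onto the first two principal components of X."""
    mean = X.mean(axis=0)
    Xc = X - mean
    # Rows of Vt are principal directions, ordered by decreasing singular value.
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    components = Vt[:2].T                  # shape (n_features, 2)
    return Xc @ components, (archetypes - mean) @ components


rng = np.random.default_rng(42)
X = rng.random((100, 5))
archetypes = X[:3]                         # stand-in archetypes for illustration only
X_2d, A_2d = pca_project_2d(X, archetypes)
print(X_2d.shape, A_2d.shape)  # (100, 2) (3, 2)
```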

Evaluation Metrics

Evaluate the quality of your archetypal analysis:

from archetypax import ArchetypalAnalysis
from archetypax.tools.evaluation import ArchetypalAnalysisEvaluator

# Fit model
model = ArchetypalAnalysis(n_archetypes=4)
model.fit(X)

# Create an evaluator
evaluator = ArchetypalAnalysisEvaluator(model)

# Calculate reconstruction error
error = evaluator.reconstruction_error(X, metric="frobenius")
print(f"Reconstruction error: {error:.4f}")

# Calculate explained variance
variance = evaluator.explained_variance(X)
print(f"Explained variance: {variance:.4f}")

# Get comprehensive evaluation metrics
metrics = evaluator.comprehensive_evaluation(X)
print("Clustering metrics:", metrics["clustering"])
print("Separation metrics:", evaluator.archetype_separation())
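For reference, the two headline metrics have standard definitions, sketched below in NumPy with hypothetical helper names. archetypax's evaluator may normalize them differently (for example, reporting a relative error), so treat these as the textbook formulas rather than its exact implementation: the Frobenius error is the square root of the summed squared residuals, and explained variance is 1 minus the residual sum of squares divided by the total sum of squares around the feature means.

```python
import numpy as np


def frobenius_error(X, X_hat):
    """Frobenius norm of the residual matrix."""
    return np.linalg.norm(X - X_hat, ord="fro")


def explained_variance(X, X_hat):
    """1 - (residual sum of squares) / (total sum of squares around column means)."""
    ss_res = np.sum((X - X_hat) ** 2)
    ss_tot = np.sum((X - X.mean(axis=0)) ** 2)
    return 1.0 - ss_res / ss_tot


rng = np.random.default_rng(42)
X = rng.random((100, 5))
X_mean = np.tile(X.mean(axis=0), (X.shape[0], 1))   # baseline: predict the column means

print(explained_variance(X, X))        # 1.0 for a perfect reconstruction
print(explained_variance(X, X_mean))   # 0.0 (up to rounding) for the mean baseline
```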

Interpretation Tools

Interpret your archetypal analysis results and compare models with different numbers of archetypes:

from archetypax import ArchetypalAnalysis
from archetypax.tools.interpret import ArchetypalAnalysisInterpreter

# Create an interpreter
interpreter = ArchetypalAnalysisInterpreter()

# Add a fitted model
model = ArchetypalAnalysis(n_archetypes=3)
model.fit(X)
interpreter.add_model(3, model)  # Add model with key=3 (number of archetypes)

# Add more models with different numbers of archetypes
model4 = ArchetypalAnalysis(n_archetypes=4)
model4.fit(X)
interpreter.add_model(4, model4)

# Calculate feature distinctiveness for archetypes
distinctiveness = interpreter.feature_distinctiveness(model.archetypes)
print("Feature distinctiveness:", distinctiveness)

# Find optimal number of archetypes
optimal_k = interpreter.find_optimal_k(X, k_range=[2, 3, 4, 5, 6])
print(f"Optimal number of archetypes: {optimal_k}")
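The criterion that find_optimal_k applies is not described here. A common heuristic for choosing the number of archetypes is the "elbow" of the reconstruction-error curve, sketched below in NumPy. elbow_k is a hypothetical helper, and the error values are illustrative numbers, not results; in practice they could come from fitting one model per k and measuring its reconstruction error.

```python
import numpy as np


def elbow_k(k_values, errors):
    """Return the k whose (k, error) point lies furthest from the straight line
    joining the first and last points, after scaling both axes to [0, 1]."""
    k = np.asarray(k_values, dtype=float)
    e = np.asarray(errors, dtype=float)
    k_n = (k - k.min()) / (k.max() - k.min())
    e_n = (e - e.min()) / (e.max() - e.min())
    p0 = np.array([k_n[0], e_n[0]])
    p1 = np.array([k_n[-1], e_n[-1]])
    d = (p1 - p0) / np.linalg.norm(p1 - p0)          # unit direction of the chord
    pts = np.stack([k_n, e_n], axis=1) - p0
    # Perpendicular distance to the chord via the 2D cross product.
    dist = np.abs(pts[:, 0] * d[1] - pts[:, 1] * d[0])
    return int(k[np.argmax(dist)])


# Illustrative error curve: large drop from 2 to 3, then diminishing returns.
print(elbow_k([2, 3, 4, 5, 6], [10.0, 4.0, 3.0, 2.6, 2.4]))  # 3
```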

Integration with scikit-learn

Because the estimators follow the scikit-learn fit/transform interface, they can be used as steps in a scikit-learn Pipeline:

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from archetypax import ArchetypalAnalysis

# Create a pipeline
pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('archetype', ArchetypalAnalysis(n_archetypes=3))
])

# Fit and transform
pipeline.fit(X)
weights = pipeline.transform(X)
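With a StandardScaler in front, the archetypes are learned in standardized units. Since standardization is an affine map and archetypal weights sum to 1, archetypes can be mapped back to original units without changing the weights; with the real objects this would be the scaler's inverse_transform (for example via pipeline.named_steps["scaler"]). The NumPy sketch below checks that property using stand-in archetypes rather than a fitted model.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.random((100, 5)) * [1, 10, 100, 1000, 10000]  # features on very different scales
mean, std = X.mean(axis=0), X.std(axis=0)
X_scaled = (X - mean) / std              # same formula StandardScaler applies (ddof=0)

Z_scaled = X_scaled[:3]                  # stand-in archetypes "fitted" in scaled space
R = rng.random((100, 3))
W = R / R.sum(axis=1, keepdims=True)     # simplex weights: rows sum to 1

# Undo the scaling on the archetypes themselves.
Z_original = Z_scaled * std + mean

# Reconstruct-then-unscale equals unscale-then-reconstruct because rows of W sum to 1.
lhs = (W @ Z_scaled) * std + mean
rhs = W @ Z_original
print(np.allclose(lhs, rhs))  # True
```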

Sparse Archetypal Analysis

For datasets where interpretability is paramount, archetypax offers SparseArchetypalAnalysis, which adds a sparsity-inducing regularization term so that each archetype concentrates on fewer features:

from archetypax import SparseArchetypalAnalysis
import numpy as np

# Generate synthetic data
np.random.seed(42)
X = np.random.rand(100, 10)  # 100 samples, 10 features

# Initialize sparse archetypal analysis
model = SparseArchetypalAnalysis(
    n_archetypes=3,
    lambda_sparsity=0.1,     # Controls the strength of sparsity regularization
    sparsity_method="l1",    # Options: "l1", "l0_approx", "feature_selection"
    max_iter=500
)

# Fit the model
model.fit(X)

# Get archetypes
archetypes = model.archetypes
print("Archetypes shape:", archetypes.shape)

# Calculate sparsity scores for each archetype
sparsity_scores = model.get_archetype_sparsity()
print("Archetype sparsity scores:", sparsity_scores)

# Transform data to get weights
weights = model.transform(X)

# Visualize sparse archetypes
import matplotlib.pyplot as plt
from archetypax.tools.visualization import ArchetypalAnalysisVisualizer

# Plot archetype profiles showing the sparse feature representation
ArchetypalAnalysisVisualizer.plot_archetype_profiles(
    model,
    feature_names=[f"F{i+1}" for i in range(X.shape[1])],
    highlight_threshold=0.1  # Highlight features above this threshold
)
plt.show()
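How get_archetype_sparsity scores each archetype is not specified here. One widely used per-vector measure is Hoyer's sparseness (Hoyer, 2004), which is 0 when all entries have equal magnitude and 1 when exactly one entry is non-zero. The sketch below computes it per archetype row with a hypothetical helper; it assumes no archetype is entirely zero.

```python
import numpy as np


def hoyer_sparsity(archetypes):
    """Hoyer sparseness per row: (sqrt(n) - L1/L2) / (sqrt(n) - 1)."""
    A = np.abs(np.asarray(archetypes, dtype=float))
    n = A.shape[1]
    l1 = A.sum(axis=1)
    l2 = np.sqrt((A ** 2).sum(axis=1))
    return (np.sqrt(n) - l1 / l2) / (np.sqrt(n) - 1)


archetypes = np.array([
    [0.0, 0.9, 0.0, 0.0],   # one active feature: maximally sparse
    [0.5, 0.5, 0.5, 0.5],   # all features equal: not sparse
    [0.2, 0.6, 0.0, 0.1],   # in between
])
scores = hoyer_sparsity(archetypes)
print(np.round(scores, 3))
```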

Benefits of sparse archetypes include:

  • Improved Interpretability: Each archetype focuses on fewer, more meaningful features

  • Feature Selection: Automatically identifies the most important features for each archetype

  • Reduced Overfitting: Sparsity acts as a form of regularization

  • Clearer Patterns: Makes archetypal patterns more distinct and easier to interpret

The sparsity_method parameter selects how sparsity is encouraged:

  • "l1": Standard L1 (lasso-style) penalty, which shrinks archetype entries toward zero

  • "l0_approx": Smooth approximation of the L0 norm (the number of non-zero entries) for more aggressive sparsity

  • "feature_selection": Focuses on selecting distinct feature subsets for each archetype
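To make the difference between the first two options concrete, the NumPy sketch below compares an L1 penalty with one common smooth L0 approximation, 1 - exp(-z^2 / (2 sigma^2)) summed over entries. It also shows soft-thresholding, the proximal step typically paired with L1. These are generic textbook forms with hypothetical helper names; archetypax's exact penalties may differ, and "feature_selection" is not sketched here.

```python
import numpy as np


def l1_penalty(Z, lam):
    """lam times the sum of absolute values: large entries cost more."""
    return lam * np.abs(Z).sum()


def smooth_l0_penalty(Z, lam, sigma=0.1):
    """Approaches lam times the count of non-zero entries as sigma -> 0."""
    return lam * np.sum(1.0 - np.exp(-(Z ** 2) / (2 * sigma ** 2)))


def soft_threshold(Z, t):
    """Proximal operator of t * ||Z||_1: shrinks entries by t and zeroes small ones."""
    return np.sign(Z) * np.maximum(np.abs(Z) - t, 0.0)


Z = np.array([[0.05, 0.8, 0.0],
              [0.3, 0.02, 0.6]])
print(l1_penalty(Z, 1.0))                      # sum of magnitudes
print(smooth_l0_penalty(Z, 1.0, sigma=1e-3))   # close to 5, the number of non-zeros
print(soft_threshold(Z, 0.1))                  # small entries become exactly zero
```

The L1 penalty charges more for larger entries, while the L0-style penalty charges roughly the same for every non-zero entry regardless of size. This is why it pushes harder toward having few active features.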