"""Sparse Archetypal Analysis model utilizing JAX."""
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
from scipy.spatial import ConvexHull
from archetypax.logger import get_logger, get_message
from archetypax.models.archetypes import ImprovedArchetypalAnalysis
[docs]
class SparseArchetypalAnalysis(ImprovedArchetypalAnalysis):
"""Archetypal Analysis incorporating sparsity constraints on archetypes.
This implementation enhances the ImprovedArchetypalAnalysis by introducing
sparsity constraints to the archetypes, thereby improving interpretability,
particularly in high-dimensional datasets.
"""
[docs]
def __init__(
self,
n_archetypes: int,
max_iter: int = 500,
tol: float = 1e-6,
random_seed: int = 42,
learning_rate: float = 0.001,
lambda_reg: float = 0.01,
lambda_sparsity: float = 0.1,
sparsity_method: str = "l1",
normalize: bool = False,
projection_method: str = "cbap",
projection_alpha: float = 0.1,
archetype_init_method: str = "directional",
min_volume_factor: float = 0.001,
**kwargs,
):
"""Initialize the Sparse Archetypal Analysis model.
Args:
n_archetypes: Number of archetypes to extract.
max_iter: Maximum number of iterations for optimization.
tol: Convergence tolerance.
random_seed: Random seed for reproducibility.
learning_rate: Learning rate for the optimizer.
lambda_reg: Regularization strength for weights.
lambda_sparsity: Regularization strength for archetype sparsity.
sparsity_method: Method for enforcing sparsity ("l1", "l0_approx", or "feature_selection").
normalize: Whether to normalize data prior to fitting.
projection_method: Method for projecting archetypes ("cbap", "convex_hull", or "knn").
projection_alpha: Strength of projection (0-1).
archetype_init_method: Method for initializing archetypes
("directional", "qhull", "kmeans_pp").
min_volume_factor: Minimum volume factor to prevent degeneracy (0-1).
**kwargs: Additional keyword arguments.
"""
super().__init__(
n_archetypes=n_archetypes,
max_iter=max_iter,
tol=tol,
random_seed=random_seed,
learning_rate=learning_rate,
lambda_reg=lambda_reg,
normalize=normalize,
projection_method=projection_method,
projection_alpha=projection_alpha,
archetype_init_method=archetype_init_method,
**kwargs,
)
# Initialize a class-specific logger with the updated class name.
self.logger = get_logger(f"{__name__}.{self.__class__.__name__}")
self.logger.info(
get_message(
"init",
"model_init",
model_name=self.__class__.__name__,
n_archetypes=n_archetypes,
sparsity_method=sparsity_method,
lambda_sparsity=lambda_sparsity,
min_volume_factor=min_volume_factor,
)
)
self.rng_key = jax.random.key(random_seed)
self.lambda_sparsity = lambda_sparsity
self.sparsity_method = sparsity_method
self.min_volume_factor = min_volume_factor # Parameter controlling the minimum volume of the convex hull.
self.early_stopping_patience = kwargs.get("early_stopping_patience", 100)
self.verbose_level = kwargs.get("verbose_level", 1)
[docs]
@partial(jax.jit, static_argnums=(0,))
def loss_function(self, archetypes, weights, X):
"""JIT-compiled loss function incorporating a sparsity constraint on archetypes."""
archetypes_f32 = archetypes.astype(jnp.float32)
weights_f32 = weights.astype(jnp.float32)
X_f32 = X.astype(jnp.float32)
X_reconstructed = jnp.matmul(weights_f32, archetypes_f32)
reconstruction_loss = jnp.mean(jnp.sum((X_f32 - X_reconstructed) ** 2, axis=1))
# Calculate entropy for weights (higher values indicate uniform weights, lower values indicate sparse weights).
weight_entropy = -jnp.sum(weights_f32 * jnp.log(weights_f32 + 1e-10), axis=1)
weight_entropy_reg = jnp.mean(weight_entropy)
# Introduce an incentive for archetypes to remain near the boundary of the convex hull.
boundary_incentive = self._calculate_boundary_proximity(archetypes_f32, X_f32)
# Compute the sparsity penalty based on the selected method.
if self.sparsity_method == "l1":
# L1 regularization to encourage sparsity in archetypes.
sparsity_penalty = jnp.mean(jnp.sum(jnp.abs(archetypes_f32), axis=1))
elif self.sparsity_method == "l0_approx":
# Approximation of the L0 norm using a continuous function.
# This provides a smoother approximation for counting non-zero elements.
epsilon = 1e-6
sparsity_penalty = jnp.mean(jnp.sum(1 - jnp.exp(-(archetypes_f32**2) / epsilon), axis=1))
elif self.sparsity_method == "feature_selection":
# Encourages each archetype to concentrate on a subset of features
# by penalizing uniform distribution across features.
archetype_entropy = -jnp.sum(archetypes_f32 * jnp.log(archetypes_f32 + 1e-10), axis=1)
sparsity_penalty = jnp.mean(archetype_entropy)
else:
# Default to L1 if the method is unrecognized.
sparsity_penalty = jnp.mean(jnp.sum(jnp.abs(archetypes_f32), axis=1))
# Calculate the archetype diversity penalty based on pairwise similarity.
n_archetypes = archetypes_f32.shape[0]
archetype_diversity_penalty = 0.0
if n_archetypes > 1:
# Compute the normalized cosine similarity matrix between archetypes.
norms = jnp.sqrt(jnp.sum(archetypes_f32**2, axis=1, keepdims=True))
normalized_archetypes = archetypes_f32 / jnp.maximum(norms, 1e-10)
similarity_matrix = jnp.dot(normalized_archetypes, normalized_archetypes.T)
# Exclude diagonal elements (self-similarity is always 1).
mask = jnp.ones((n_archetypes, n_archetypes)) - jnp.eye(n_archetypes)
masked_similarities = similarity_matrix * mask
# Retrieve the maximum similarity (higher values indicate a problem).
archetype_diversity_penalty = jax.device_get(jnp.mean(jnp.maximum(masked_similarities, 0)))
# Add the archetype diversity penalty to the total loss (higher similarity = lower diversity penalty).
diversity_weight = 0.1
# Combined loss incorporating reconstruction, regularizations, boundary incentive, and diversity.
total_loss = (
reconstruction_loss
+ self.lambda_reg * weight_entropy_reg
+ self.lambda_sparsity * sparsity_penalty
- 0.001 * boundary_incentive
+ diversity_weight * archetype_diversity_penalty
)
return total_loss.astype(jnp.float32)
def _calculate_simplex_volume(self, archetypes):
"""Calculate the volume of the simplex formed by the archetypes.
This is utilized to ensure that the archetypes do not collapse into a degenerate subspace.
Args:
archetypes: Array of shape (n_archetypes, n_features).
Returns:
Approximate volume of the simplex.
"""
n_archetypes, n_features = archetypes.shape
# If there are fewer archetypes than dimensions + 1, the volume is technically zero.
# Instead, we will compute a proxy metric based on pairwise distances.
if n_archetypes <= n_features:
# Calculate pairwise distances between archetypes.
pairwise_distances = np.zeros((n_archetypes, n_archetypes))
for i in range(n_archetypes):
for j in range(i + 1, n_archetypes):
dist = np.linalg.norm(archetypes[i] - archetypes[j])
pairwise_distances[i, j] = pairwise_distances[j, i] = dist
# Use the product of pairwise distances as a proxy for volume.
# Higher values indicate that archetypes are more spread out.
volume_proxy = np.sum(pairwise_distances) / (n_archetypes * (n_archetypes - 1) / 2)
return volume_proxy
else:
try:
# When there are enough points, attempt to compute the actual convex hull volume.
hull = ConvexHull(archetypes)
return hull.volume
except Exception:
# Fallback to the pairwise distance approach if the convex hull computation fails.
pairwise_distances = np.zeros((n_archetypes, n_archetypes))
for i in range(n_archetypes):
for j in range(i + 1, n_archetypes):
dist = np.linalg.norm(archetypes[i] - archetypes[j])
pairwise_distances[i, j] = pairwise_distances[j, i] = dist
volume_proxy = np.sum(pairwise_distances) / (n_archetypes * (n_archetypes - 1) / 2)
return volume_proxy
[docs]
@partial(jax.jit, static_argnums=(0,))
def update_archetypes(self, archetypes, weights, X) -> jnp.ndarray:
"""Update archetypes with an additional step to promote sparsity."""
# First, perform the standard archetype update.
archetypes_updated = super().update_archetypes(archetypes, weights, X)
# Apply an additional sparsity promotion based on the selected method.
if self.sparsity_method == "feature_selection":
# For feature selection, apply soft thresholding to enhance feature selectivity.
# This step retains the largest values in each archetype while shrinking smaller values.
# Calculate thresholds for each archetype (adaptive thresholding).
thresholds = jnp.percentile(archetypes_updated, 50, axis=1, keepdims=True)
# Soft thresholding: shrink values below the threshold.
shrinkage_factor = 0.7 # Controls the aggressiveness of shrinking small values.
mask = archetypes_updated < thresholds
archetypes_updated = jnp.where(mask, archetypes_updated * shrinkage_factor, archetypes_updated)
# Re-normalize to maintain simplex constraints.
row_sums = jnp.sum(archetypes_updated, axis=1, keepdims=True)
archetypes_updated = archetypes_updated / jnp.maximum(1e-10, row_sums)
# Calculate feature-wise variance of archetypes to identify potential degeneracy.
# If variance is too low in some features across archetypes, it suggests potential degeneracy.
archetype_variance = jnp.var(archetypes_updated, axis=0, keepdims=True)
# Introduce small noise in the direction of low variance to prevent degeneracy.
# Scale noise inversely with the variance to target low-variance dimensions.
noise_scale = 0.01
# Use jax.random instead of jnp.random
_, noise_key = jax.random.split(self.rng_key) # Fixed seed for deterministic noise
noise = jax.random.uniform(noise_key, shape=archetypes_updated.shape) - 0.5 # Zero-centered noise
# Scale noise inversely proportional to variance (more noise where variance is low).
# Add a small epsilon to avoid division by zero.
variance_scaling = noise_scale / (jnp.sqrt(archetype_variance) + 1e-8)
scaled_noise = noise * variance_scaling
# Apply noise selectively to avoid disrupting the sparsity pattern.
# Only add noise where the archetype elements are already non-zero.
archetypes_with_noise = archetypes_updated + scaled_noise * (archetypes_updated > 1e-5)
# Re-normalize to maintain simplex constraints.
row_sums = jnp.sum(archetypes_with_noise, axis=1, keepdims=True)
archetypes_with_noise = archetypes_with_noise / jnp.maximum(1e-10, row_sums)
# Ensure archetypes remain within the convex hull.
centroid = jnp.mean(X, axis=0)
# Process each archetype to ensure it resides within the convex hull.
def _constrain_to_convex_hull(archetype):
# Direction from centroid to archetype.
direction = archetype - centroid
direction_norm = jnp.linalg.norm(direction)
# Handle near-zero norm case.
normalized_direction = jnp.where(
direction_norm > 1e-10, direction / direction_norm, jnp.zeros_like(direction)
)
# Project all points onto this direction.
projections = jnp.dot(X - centroid, normalized_direction)
# Identify the maximum projection (extreme point in this direction).
max_projection = jnp.max(projections)
# Calculate the archetype projection along this direction.
archetype_projection = jnp.dot(archetype - centroid, normalized_direction)
# Scale factor to bring the archetype inside the convex hull if it is outside.
# Apply a small margin (0.99) to ensure it is strictly inside.
scale_factor = jnp.where(
archetype_projection > max_projection,
0.99 * max_projection / (archetype_projection + 1e-10),
1.0,
)
# Apply the scaling to the direction vector.
constrained_archetype = centroid + scale_factor * (archetype - centroid)
return constrained_archetype
# Apply the constraint to each archetype.
constrained_archetypes = jax.vmap(_constrain_to_convex_hull)(archetypes_with_noise)
return jnp.asarray(constrained_archetypes)
[docs]
def diversify_archetypes(self, archetypes, X):
"""Non-differentiable post-processing step to ensure archetype diversity.
This method guarantees that the archetypes form a non-degenerate simplex with sufficient volume.
It is invoked outside the JAX-compiled update steps, as it employs non-differentiable operations.
Args:
archetypes: Current archetypes array.
X: Data matrix.
Returns:
Diversified archetypes array.
"""
archetypes_np = np.array(archetypes)
X_np = np.array(X)
n_archetypes = archetypes_np.shape[0]
# Calculate initial volume or proxy.
initial_volume = self._calculate_simplex_volume(archetypes_np)
# If the volume is exceedingly small, attempt to increase it.
if initial_volume < self.min_volume_factor:
self.logger.info(
f"Detected a potentially degenerate archetype configuration. "
f"Volume proxy: {initial_volume:.6f}. Attempting to diversify."
)
# Calculate centroid.
centroid = np.mean(X_np, axis=0)
# For each archetype, push it away from other archetypes.
for i in range(n_archetypes):
# Calculate direction away from the average of other archetypes.
other_archetypes = np.delete(archetypes_np, i, axis=0)
other_centroid = np.mean(other_archetypes, axis=0)
# Direction away from other archetypes.
direction = archetypes_np[i] - other_centroid
direction_norm = np.linalg.norm(direction)
if direction_norm > 1e-10:
normalized_direction = direction / direction_norm
# Determine how far we can push in this direction while remaining within the convex hull.
projections = np.dot(X_np - centroid, normalized_direction)
max_projection = np.max(projections)
current_projection = np.dot(archetypes_np[i] - centroid, normalized_direction)
# Push outward, but remain within the convex hull.
# Utilize a blend of the current position and the extreme point.
blend_factor = 0.5 # Move halfway toward the extreme point.
if current_projection < max_projection:
target_projection = current_projection + blend_factor * (max_projection - current_projection)
archetypes_np[i] = centroid + normalized_direction * target_projection
# Re-normalize to maintain simplex constraints.
row_sums = np.sum(archetypes_np, axis=1, keepdims=True)
archetypes_np = archetypes_np / np.maximum(1e-10, row_sums)
# Verify improvement.
new_volume = self._calculate_simplex_volume(archetypes_np)
self.logger.info(
f"After diversification, volume proxy changed from {initial_volume:.6f} to {new_volume:.6f}"
)
return jnp.array(archetypes_np)
[docs]
def get_archetype_sparsity(self) -> np.ndarray:
"""Calculate the sparsity of each archetype.
Returns:
Array containing the sparsity score for each archetype.
Higher values indicate more sparse archetypes.
"""
if not hasattr(self, "archetypes") or self.archetypes is None:
raise ValueError("The model has not yet been fitted.")
archetypes = self.archetypes
n_archetypes, n_features = archetypes.shape
sparsity_scores = np.zeros(n_archetypes)
for i in range(n_archetypes):
# Calculate the Gini coefficient as a measure of sparsity.
# (An alternative to directly counting zeroes, which isn't differentiable).
sorted_values = np.sort(np.abs(archetypes[i]))
cumsum = np.cumsum(sorted_values)
gini = 1 - 2 * np.sum(cumsum) / (n_features * np.sum(sorted_values))
sparsity_scores[i] = gini
return sparsity_scores
[docs]
def fit(
self,
X: np.ndarray,
normalize: bool = False,
**kwargs,
) -> "SparseArchetypalAnalysis":
"""Fit the Sparse Archetypal Analysis model to the data.
Args:
X: Input data matrix of shape (n_samples, n_features).
normalize: Whether to normalize data prior to fitting.
**kwargs: Additional keyword arguments.
Returns:
Fitted model instance.
"""
self.logger.info(
get_message(
"training",
"model_training_start",
sparsity_method=self.sparsity_method,
lambda_sparsity=self.lambda_sparsity,
)
)
X_np = X.values if hasattr(X, "values") else X
# Utilize the parent class fit method.
model = super().fit(X_np, normalize, **kwargs)
# Apply post-processing to ensure archetype diversity.
# This is executed outside the JAX-compiled optimization.
if hasattr(model, "archetypes") and model.archetypes is not None:
X_jax = jnp.array(X_np if not normalize else (X_np - model.X_mean) / model.X_std)
model.archetypes = self.diversify_archetypes(model.archetypes, X_jax)
return model