"""Biarchetypal Analysis: Dual-perspective archetype discovery using JAX.
This module implements Biarchetypal Analysis (BA), which extends traditional
Archetypal Analysis by simultaneously identifying archetypes in both observation
space (rows) and feature space (columns). This dual-perspective approach enables
more comprehensive data understanding, revealing patterns that would remain hidden
in single-direction analysis.
"""
from functools import partial
from typing import Any, TypeVar
import jax
import jax.numpy as jnp
import numpy as np
import optax
from archetypax.logger import get_logger, get_message
from .archetypes import ImprovedArchetypalAnalysis
T = TypeVar("T", bound=np.ndarray)
[docs]
class BiarchetypalAnalysis(ImprovedArchetypalAnalysis):
"""Biarchetypal Analysis for dual-directional pattern discovery.
This implementation extends archetypal analysis to simultaneously identify
extreme patterns in both observations (rows) and features (columns), offering
a richer understanding of data structure. Traditional archetypal analysis
only identifies patterns in observation space, missing crucial feature-level insights.
By factorizing the data matrix X as:
X ≃ alpha·beta·X·theta·gamma
BA provides several advantages:
- Captures both observation-level and feature-level patterns
- Enables cross-modal analysis between observations and features
- Creates a more compact and interpretable representation via biarchetypes
- Reveals latent relationships that single-directional methods cannot detect
Based on the work by Alcacer et al., "Biarchetype analysis: simultaneous learning
of observations and features based on extremes."
"""
[docs]
def __init__(
self,
n_row_archetypes: int,
n_col_archetypes: int,
max_iter: int = 500,
tol: float = 1e-6,
random_seed: int = 42,
learning_rate: float = 0.001,
projection_method: str = "default",
lambda_reg: float = 0.01,
**kwargs,
):
"""Initialize the Biarchetypal Analysis model.
Args:
n_row_archetypes: Number of row archetypes - controls expressiveness in
observation space (rows)
n_col_archetypes: Number of column archetypes - controls expressiveness in
feature space (columns)
max_iter: Maximum optimization iterations - higher values enable better
convergence at computational cost
tol: Convergence tolerance for early stopping - smaller values yield more
precise solutions but require more iterations
random_seed: Random seed for reproducibility across runs
learning_rate: Gradient descent step size - critical balance between
convergence speed and stability
projection_method: Method for projecting archetypes to extreme points:
"default" uses convex boundary approximation
lambda_reg: Regularization strength - controls sparsity/smoothness tradeoff
in archetype weights
**kwargs: Additional parameters including:
- early_stopping_patience: Iterations with no improvement before stopping
- verbose_level/logger_level: Controls logging detail
"""
# Initialize using parent class with the row archetypes
# (we'll handle column archetypes separately)
super().__init__(
n_archetypes=n_row_archetypes,
max_iter=max_iter,
tol=tol,
random_seed=random_seed,
learning_rate=learning_rate,
**kwargs,
)
# Initialize class-specific logger
if isinstance(kwargs.get("logger_level"), str) and kwargs.get("logger_level") is not None:
logger_level = kwargs["logger_level"].upper()
elif isinstance(kwargs.get("logger_level"), int) and kwargs.get("logger_level") is not None:
logger_level = {
0: "DEBUG",
1: "INFO",
2: "WARNING",
3: "ERROR",
4: "CRITICAL",
}[kwargs["logger_level"]]
elif "logger_level" not in kwargs and "verbose_level" in kwargs and kwargs["verbose_level"] is not None:
logger_level = {
4: "DEBUG",
3: "INFO",
2: "WARNING",
1: "ERROR",
0: "CRITICAL",
}[kwargs["verbose_level"]]
else:
logger_level = "ERROR"
self.logger = get_logger(f"{__name__}.{self.__class__.__name__}", level=logger_level)
self.logger.info(
get_message(
"init",
"model_init",
model_name=self.__class__.__name__,
n_row_archetypes=n_row_archetypes,
n_col_archetypes=n_col_archetypes,
max_iter=max_iter,
tol=tol,
random_seed=random_seed,
learning_rate=learning_rate,
projection_method=projection_method,
lambda_reg=lambda_reg,
)
)
# Store biarchetypal specific parameters
self.n_row_archetypes = n_row_archetypes
self.n_col_archetypes = n_col_archetypes
self.lambda_reg = lambda_reg
self.random_seed = random_seed
# Will be set during fitting
self.alpha: np.ndarray | None = None # Row coefficients (n_samples, n_row_archetypes)
self.beta: np.ndarray | None = None # Row archetypes (n_row_archetypes, n_samples)
self.theta: np.ndarray | None = None # Column archetypes (n_features, n_col_archetypes)
self.gamma: np.ndarray | None = None # Column coefficients (n_col_archetypes, n_features)
self.biarchetypes: np.ndarray | None = None # β·X·θ (n_row_archetypes, n_col_archetypes)
self.early_stopping_patience = kwargs.get("early_stopping_patience", 100)
[docs]
@partial(jax.jit, static_argnums=(0,))
def loss_function(self, params: dict[str, jnp.ndarray], X: jnp.ndarray) -> jnp.ndarray:
"""Calculate the composite reconstruction loss for biarchetypal factorization.
This core objective function balances reconstruction quality with sparsity
promotion to ensure interpretable representations. Unlike standard AA,
the biarchetypal loss operates on a four-factor decomposition, requiring
careful numerical handling to prevent instability during optimization.
The loss promotes three key properties:
1. Accurate data reconstruction through the biarchetypal representation
2. Sparse coefficients for interpretable patterns
3. Numerical stability through explicit type control
Args:
params: Dictionary containing the four model matrices:
- alpha: Row coefficients (n_samples, n_row_archetypes)
- beta: Row archetypes (n_row_archetypes, n_samples)
- theta: Column archetypes (n_features, n_col_archetypes)
- gamma: Column coefficients (n_col_archetypes, n_features)
X: Data matrix (n_samples, n_features)
Returns:
Combined loss value incorporating reconstruction and regularization terms
"""
# Convert to float32 for better numerical stability
alpha = params["alpha"].astype(jnp.float32) # (n_samples, n_row_archetypes)
beta = params["beta"].astype(jnp.float32) # (n_row_archetypes, n_samples)
theta = params["theta"].astype(jnp.float32) # (n_features, n_col_archetypes)
gamma = params["gamma"].astype(jnp.float32) # (n_col_archetypes, n_features)
X_f32 = X.astype(jnp.float32)
# Calculate the reconstruction: X ≃ alpha·beta·X·theta·gamma
# Optimize matrix multiplications to reduce memory usage
inner_product = jnp.matmul(jnp.matmul(beta, X_f32), theta) # (n_row_archetypes, n_col_archetypes)
reconstruction = jnp.matmul(jnp.matmul(alpha, inner_product), gamma) # (n_samples, n_features)
# Calculate the reconstruction error (element-wise MSE)
reconstruction_loss = jnp.mean(jnp.sum((X_f32 - reconstruction) ** 2, axis=1))
# Add regularization to encourage sparsity by minimizing entropy
alpha_entropy = jnp.sum(alpha * jnp.log(alpha + 1e-10), axis=1) # Higher values promote sparsity
gamma_entropy = jnp.sum(gamma * jnp.log(gamma + 1e-10), axis=0) # Higher values promote sparsity
entropy_reg = jnp.mean(alpha_entropy) + jnp.mean(gamma_entropy)
return (reconstruction_loss - self.lambda_reg * entropy_reg).astype(jnp.float32)
[docs]
@partial(jax.jit, static_argnums=(0,))
def project_row_coefficients(self, coefficients: jnp.ndarray) -> jnp.ndarray:
"""Project row coefficients to satisfy simplex constraints.
This projection is essential for maintaining valid convex combinations
in the observation space. The simplex constraint (non-negative weights
summing to 1) ensures that each data point is represented as a proper
weighted combination of row archetypes. Without this constraint, the
model would lose its interpretability and might generate unrealistic
representations.
The implementation includes numerical safeguards to prevent division by zero
and ensure stable optimization even with extreme weight values.
Args:
coefficients: Row coefficient matrix (n_samples, n_row_archetypes)
Returns:
Projected coefficients satisfying simplex constraints
(non-negative, sum to 1)
"""
eps = 1e-10
coefficients = jnp.maximum(eps, coefficients)
sum_coeffs = jnp.sum(coefficients, axis=1, keepdims=True)
sum_coeffs = jnp.maximum(eps, sum_coeffs)
return coefficients / sum_coeffs
[docs]
@partial(jax.jit, static_argnums=(0,))
def project_col_coefficients(self, coefficients: jnp.ndarray) -> jnp.ndarray:
"""Project column coefficients to satisfy simplex constraints.
This projection enforces valid convex combinations in the feature space,
which differs critically from row coefficient projection. Feature weights
must sum to 1 across columns (not rows), ensuring each feature is properly
represented by column archetypes.
This axis-specific projection is a key distinction between standard AA and
biarchetypal analysis, enabling the dual-directional nature of the model.
Args:
coefficients: Column coefficient matrix (n_col_archetypes, n_features)
Returns:
Projected coefficients with each feature's weights summing to 1,
maintaining valid convex combinations in feature space
"""
eps = 1e-10
coefficients = jnp.maximum(eps, coefficients)
sum_coeffs = jnp.sum(coefficients, axis=0, keepdims=True)
sum_coeffs = jnp.maximum(eps, sum_coeffs)
return coefficients / sum_coeffs
[docs]
@partial(jax.jit, static_argnums=(0,))
def project_row_archetypes(self, archetypes: jnp.ndarray, X: jnp.ndarray) -> jnp.ndarray:
"""Project row archetypes to the convex hull boundary of data points.
This critical operation ensures row archetypes remain at meaningful extremes
of the observation space, where they represent distinct, interpretable patterns.
Without this projection, archetypes would tend to collapse toward the data
centroid during optimization, losing their representative power.
The implementation uses an adaptive multi-point boundary approximation that:
1. Identifies extreme directions from the data centroid
2. Selects multiple boundary points along each direction
3. Creates weighted combinations that maximize distinctiveness
4. Maintains numeric stability throughout the process
Args:
archetypes: Row archetype matrix (n_row_archetypes, n_samples)
X: Data matrix (n_samples, n_features)
Returns:
Projected row archetypes positioned at meaningful boundaries
of the data's convex hull
"""
# Calculate the data centroid as our reference point
centroid = jnp.mean(X, axis=0) # Shape: (n_features,)
def _project_to_boundary(archetype):
"""Project a single archetype to the boundary of the convex hull."""
# Step 1: Calculate direction from centroid to archetype representation
weighted_representation = jnp.matmul(archetype, X) # Shape: (n_features,)
direction = weighted_representation - centroid
direction_norm = jnp.linalg.norm(direction)
normalized_direction = jnp.where(
direction_norm > 1e-10,
direction / direction_norm,
jax.random.normal(jax.random.PRNGKey(0), direction.shape) / jnp.sqrt(direction.shape[0]),
)
# Step 2: Project all data points onto this direction vector
projections = jnp.dot(X - centroid, normalized_direction) # Shape: (n_samples,)
# Step 3: Find multiple extreme points with adaptive k selection
k = min(5, X.shape[0] // 10 + 2) # Adaptive k based on dataset size
top_k_indices = jnp.argsort(projections)[-k:]
top_k_projections = projections[top_k_indices]
# Step 4: Calculate weights with emphasis on the most extreme points
weights_unnormalized = jnp.exp(top_k_projections - jnp.max(top_k_projections))
weights = weights_unnormalized / jnp.sum(weights_unnormalized)
# Step 5: Create a weighted combination of extreme points
multi_hot = jnp.zeros_like(archetype)
for i in range(k):
idx = top_k_indices[i]
multi_hot = multi_hot.at[idx].set(weights[i])
# Step 6: Mix with original archetype for stability and convergence
alpha = 0.8 # Stronger pull toward extreme points for better diversity
projected = alpha * multi_hot + (1 - alpha) * archetype
# Step 7: Apply simplex constraints with numerical stability safeguards
projected = jnp.maximum(1e-10, projected)
sum_projected = jnp.sum(projected)
projected = jnp.where(
sum_projected > 1e-10,
projected / sum_projected,
jnp.ones_like(projected) / projected.shape[0],
)
return projected
# Apply the projection function to each row archetype in parallel
projected_archetypes = jax.vmap(_project_to_boundary)(archetypes)
return jnp.asarray(projected_archetypes)
[docs]
@partial(jax.jit, static_argnums=(0,))
def project_col_archetypes(self, archetypes: jnp.ndarray, X: jnp.ndarray) -> jnp.ndarray:
"""Project column archetypes to the boundary of the feature space.
This critical counterpart to row archetype projection ensures column archetypes
represent distinct feature patterns. While conceptually similar to row projection,
this operation works in the transposed space, treating features as observations
and finding extremes among them.
Without this specialized projection, the feature archetypes would not capture
meaningful feature combinations, undermining the dual-perspective advantage
of biarchetypal analysis.
The implementation:
1. Transposes the problem to work in feature space
2. Identifies feature combinations that represent extremes
3. Creates boundary points through weighted feature combinations
4. Maintains numerical stability throughout
Args:
archetypes: Column archetype matrix (n_features, n_col_archetypes)
X: Data matrix (n_samples, n_features)
Returns:
Projected column archetypes positioned at the boundaries of feature space,
representing distinct feature patterns
"""
# Transpose X to work with features as data points in feature space
X_T = X.T # Shape: (n_features, n_samples)
# Calculate the feature centroid as our reference point
centroid = jnp.mean(X_T, axis=0) # Shape: (n_samples,)
def _project_feature_to_boundary(archetype: jnp.ndarray) -> jnp.ndarray:
"""Project a single column archetype to the boundary of the feature convex hull."""
# Step 1: Calculate direction in sample space using weighted features
weighted_features = archetype[:, jnp.newaxis] * X_T # Shape: (n_features, n_samples)
direction = jnp.sum(weighted_features, axis=0) - centroid # Shape: (n_samples,)
direction_norm = jnp.linalg.norm(direction)
normalized_direction = jnp.where(
direction_norm > 1e-10,
direction / direction_norm,
jax.random.normal(jax.random.PRNGKey(0), direction.shape) / jnp.sqrt(direction.shape[0]),
)
# Step 2: Project all features onto this direction to measure extremeness
projections = jnp.dot(X_T, normalized_direction) # Shape: (n_features,)
# Step 3: Find multiple extreme features with adaptive k selection
k = min(5, X.shape[1] // 10 + 2) # Adaptive k based on feature space size
top_k_indices = jnp.argsort(projections)[-k:]
top_k_projections = projections[top_k_indices]
# Step 4: Calculate weights with emphasis on the most extreme features
weights_unnormalized = jnp.exp(top_k_projections - jnp.max(top_k_projections))
weights = weights_unnormalized / jnp.sum(weights_unnormalized)
# Step 5: Create a weighted combination of extreme features
multi_hot = jnp.zeros_like(archetype)
for i in range(k):
idx = top_k_indices[i]
multi_hot = multi_hot.at[idx].set(weights[i])
# Step 6: Mix with original archetype for stability and convergence
alpha = 0.8 # Stronger pull toward extreme features for better diversity
projected = alpha * multi_hot + (1 - alpha) * archetype
# Step 7: Apply simplex constraints with numerical stability safeguards
projected = jnp.maximum(1e-10, projected)
sum_projected = jnp.sum(projected)
projected = jnp.where(
sum_projected > 1e-10,
projected / sum_projected,
jnp.ones_like(projected) / projected.shape[0],
)
return projected
# Apply the projection function to each column archetype in parallel
projected_archetypes = jax.vmap(_project_feature_to_boundary)(archetypes.T)
# Transpose the result back to original shape
return jnp.asarray(projected_archetypes.T)
[docs]
def fit(self, X: np.ndarray, normalize: bool = False, **kwargs) -> "BiarchetypalAnalysis":
"""Fit the Biarchetypal Analysis model to identify dual-perspective archetypes.
This core method performs the four-factor decomposition of the data matrix,
simultaneously discovering patterns in observation and feature spaces.
The implementation employs advanced optimization strategies including:
1. Sophisticated initialization for both row and column factors
2. Adaptive learning rate scheduling for stable convergence
3. Specialized projection operations to maintain meaningful boundaries
4. Careful numerical handling to prevent instability
5. Early stopping with convergence monitoring
These optimizations are essential due to the complexity of the four-factor model,
which is more challenging to optimize than standard Archetypal Analysis.
Args:
X: Data matrix (n_samples, n_features)
normalize: Whether to normalize features - essential for data with
different scales
**kwargs: Additional parameters for customizing the fitting process
Returns:
Self - fitted model instance with discovered biarchetypes
"""
X_np = X.values if hasattr(X, "values") else X
# Preprocess data: scale for improved stability
self.X_mean = np.mean(X_np, axis=0)
self.X_std = np.std(X_np, axis=0)
# Prevent division by zero
if self.X_std is not None:
self.X_std = np.where(self.X_std < 1e-10, np.ones_like(self.X_std), self.X_std)
if normalize:
X_scaled = (X_np - self.X_mean) / self.X_std
self.logger.info(get_message("data", "normalization", mean=self.X_mean, std=self.X_std))
else:
X_scaled = X_np.copy()
# Convert to JAX array with explicit dtype for better performance
X_jax = jnp.array(X_scaled, dtype=jnp.float32)
n_samples, n_features = X_jax.shape
# Log key data characteristics and model configuration
self.logger.info(f"Data shape: {X_jax.shape}")
self.logger.info(f"Data range: min={float(jnp.min(X_jax)):.4f}, max={float(jnp.max(X_jax)):.4f}")
self.logger.info(f"Row archetypes: {self.n_row_archetypes}")
self.logger.info(f"Column archetypes: {self.n_col_archetypes}")
# Initialize alpha (row coefficients) with more stable initialization
self.rng_key, subkey = jax.random.split(self.rng_key)
alpha_init = jax.random.uniform(
subkey, (n_samples, self.n_row_archetypes), minval=0.1, maxval=0.9, dtype=jnp.float32
)
alpha_init = self.project_row_coefficients(alpha_init)
# Initialize gamma (column coefficients)
self.rng_key, subkey = jax.random.split(self.rng_key)
gamma_init = jax.random.uniform(
subkey, (self.n_col_archetypes, n_features), minval=0.1, maxval=0.9, dtype=jnp.float32
)
gamma_init = self.project_col_coefficients(gamma_init)
# Initialize beta (row archetypes) using sophisticated k-means++ initialization
# This approach ensures diverse starting points that are well-distributed across the data space
self.rng_key, subkey = jax.random.split(self.rng_key)
# Step 1: Select initial centroids using k-means++ algorithm
# This ensures our archetypes start from diverse positions in the data space
selected_indices = jnp.zeros(self.n_row_archetypes, dtype=jnp.int32)
# Select first point randomly
first_idx = jax.random.randint(subkey, (), 0, n_samples)
selected_indices = selected_indices.at[0].set(first_idx)
# Select remaining points with probability proportional to squared distance
for i in range(1, self.n_row_archetypes):
# Calculate squared distance from each point to nearest existing centroid
min_dists = jnp.ones(n_samples) * float("inf")
# Update distances for each existing centroid
for j in range(i):
idx = selected_indices[j]
dists = jnp.sum((X_jax - X_jax[idx]) ** 2, axis=1)
min_dists = jnp.minimum(min_dists, dists)
# Zero out already selected points
for j in range(i):
idx = selected_indices[j]
min_dists = min_dists.at[idx].set(0.0)
# Select next point with probability proportional to squared distance
self.rng_key, subkey = jax.random.split(self.rng_key)
probs = min_dists / (jnp.sum(min_dists) + 1e-10)
next_idx = jax.random.choice(subkey, n_samples, p=probs)
selected_indices = selected_indices.at[i].set(next_idx)
# Step 2: Create one-hot encodings for selected points
beta_init = jnp.zeros((self.n_row_archetypes, n_samples), dtype=jnp.float32)
for i in range(self.n_row_archetypes):
idx = selected_indices[i]
beta_init = beta_init.at[i, idx].set(1.0)
# Step 3: Add controlled stochastic noise to promote exploration
# This prevents archetypes from being too rigidly defined at initialization
self.rng_key, subkey = jax.random.split(self.rng_key)
noise = jax.random.uniform(subkey, beta_init.shape, minval=0.0, maxval=0.05, dtype=jnp.float32)
beta_init = beta_init + noise
# Step 4: Ensure proper normalization to maintain simplex constraints
# Each row must sum to 1 to represent a valid convex combination
beta_init = beta_init / jnp.sum(beta_init, axis=1, keepdims=True)
self.logger.info("Row archetypes initialized with k-means++ strategy")
# Initialize theta (column archetypes) with advanced diversity-maximizing approach
# This ensures column archetypes capture the most distinctive feature patterns
self.rng_key, subkey = jax.random.split(self.rng_key)
# Step 1: Transpose data for feature-centric operations
X_T = X_jax.T # Shape: (n_features, n_samples)
# Step 2: Calculate feature diversity metrics
# Compute variance of each feature to identify informative dimensions
feature_variance = jnp.var(X_T, axis=1)
# Step 3: Select initial features using variance-weighted sampling
theta_init = jnp.zeros((n_features, self.n_col_archetypes), dtype=jnp.float32)
selected_features = jnp.zeros(self.n_col_archetypes, dtype=jnp.int32)
# Select first feature with probability proportional to variance
self.rng_key, subkey = jax.random.split(self.rng_key)
probs = feature_variance / (jnp.sum(feature_variance) + 1e-10)
first_idx = jax.random.choice(subkey, n_features, p=probs)
selected_features = selected_features.at[0].set(first_idx)
theta_init = theta_init.at[first_idx, 0].set(1.0)
# Step 4: Select remaining features to maximize diversity
for i in range(1, self.n_col_archetypes):
# Calculate minimum distance from each feature to already selected features
min_dists = jnp.ones(n_features) * float("inf")
for j in range(i):
idx = selected_features[j]
# Compute correlation-based distance to capture feature relationships
corr = jnp.abs(jnp.sum(X_T * X_T[idx, jnp.newaxis], axis=1)) / (
jnp.sqrt(jnp.sum(X_T**2, axis=1) * jnp.sum(X_T[idx] ** 2) + 1e-10)
)
# Convert correlation to distance (1 - |corr|)
dists = 1.0 - corr
min_dists = jnp.minimum(min_dists, dists)
# Zero out already selected features
for j in range(i):
idx = selected_features[j]
min_dists = min_dists.at[idx].set(0.0)
# Select feature with maximum minimum distance
next_idx = jnp.argmax(min_dists)
selected_features = selected_features.at[i].set(next_idx)
theta_init = theta_init.at[next_idx, i].set(1.0)
# Step 5: Add controlled noise to promote exploration
self.rng_key, subkey = jax.random.split(self.rng_key)
noise = jax.random.uniform(subkey, theta_init.shape, minval=0.0, maxval=0.05, dtype=jnp.float32)
theta_init = theta_init + noise
# Step 6: Ensure proper normalization to maintain simplex constraints
# Each column must sum to 1 to represent a valid convex combination
theta_init = theta_init / jnp.sum(theta_init, axis=0, keepdims=True)
self.logger.info("Column archetypes initialized with diversity-maximizing strategy")
# Set up optimizer with learning rate schedule for better convergence
# We use a sophisticated learning rate schedule with warmup and decay phases
warmup_steps = 20
decay_steps = 100
# Create a warmup schedule that linearly increases from 0 to peak learning rate
# Use a much lower learning rate to prevent divergence
reduced_lr = self.learning_rate * 0.05 # Reduce learning rate by 20x
warmup_schedule = optax.linear_schedule(init_value=0.0, end_value=reduced_lr, transition_steps=warmup_steps)
# Create a decay schedule that exponentially decays from peak to minimum learning rate
decay_schedule = optax.exponential_decay(
init_value=reduced_lr,
transition_steps=decay_steps,
decay_rate=0.95, # Even slower decay for more stable convergence
end_value=0.000001, # Very low minimum learning rate for fine-grained optimization
staircase=False, # Smooth decay rather than step-wise
)
# Combine the schedules
schedule = optax.join_schedules(schedules=[warmup_schedule, decay_schedule], boundaries=[warmup_steps])
# Create a sophisticated optimizer chain with:
# 1. Gradient clipping to prevent exploding gradients
# 2. Adam optimizer with our custom learning rate schedule
# 3. Weight decay for regularization
optimizer = optax.chain(
optax.clip_by_global_norm(0.5), # More aggressive clipping to prevent divergence
optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8), # Adam optimizer with standard parameters
optax.add_decayed_weights(weight_decay=1e-6), # Very subtle weight decay
optax.scale_by_schedule(schedule), # Apply our custom learning rate schedule
)
# Initialize parameters
params = {"alpha": alpha_init, "beta": beta_init, "theta": theta_init, "gamma": gamma_init}
opt_state = optimizer.init(params)
# Define update step with JIT compilation for speed
@partial(jax.jit, static_argnums=(3,))
def update_step(
params: dict[str, jnp.ndarray], opt_state: optax.OptState, X: jnp.ndarray, iteration: int
) -> tuple[dict[str, jnp.ndarray], optax.OptState, jnp.ndarray]:
"""Execute a single optimization step."""
# Loss function for current parameters
def loss_fn(params):
return self.loss_function(params, X)
# Compute gradients and update parameters
loss, grads = jax.value_and_grad(loss_fn)(params)
grads = jax.tree.map(lambda g: jnp.clip(g, -1.0, 1.0), grads)
updates, opt_state = optimizer.update(grads, opt_state, params)
new_params = optax.apply_updates(params, updates)
# Maintain simplex constraints
new_params["alpha"] = self.project_row_coefficients(new_params["alpha"])
new_params["gamma"] = self.project_col_coefficients(new_params["gamma"])
# Periodically project archetypes to convex hull boundary
# This is expensive, so we do it every 10 iterations
do_project = iteration % 10 == 0
def project():
return {
"alpha": new_params["alpha"],
"beta": self.project_row_archetypes(new_params["beta"], X),
"theta": self.project_col_archetypes(new_params["theta"], X),
"gamma": new_params["gamma"],
}
def no_project():
return new_params
new_params = jax.lax.cond(do_project, lambda: project(), lambda: no_project())
return new_params, opt_state, loss
# Optimization loop
prev_loss = float("inf")
self.loss_history = []
# Calculate initial loss for debugging
initial_loss = float(self.loss_function(params, X_jax))
self.logger.info(f"Initial loss: {initial_loss:.6f}")
for it in range(self.max_iter):
try:
# Execute update step
params, opt_state, loss = update_step(params, opt_state, X_jax, it)
loss_value = float(loss)
# Check for NaN
if jnp.isnan(loss_value):
self.logger.warning(get_message("warning", "nan_detected", iteration=it))
break
# Record loss
self.loss_history.append(loss_value)
# Track loss and display progress
self.logger.info(f"Iteration {it}, Loss: {loss_value:.6f}")
# Evaluate convergence using both immediate and moving average improvements
relative_improvement = (prev_loss - loss_value) / (prev_loss + 1e-10)
# Tracking using moving average for stable convergence detection
if it >= 10:
recent_losses = self.loss_history[-10:]
loss_ma = sum(recent_losses) / 10
# Use initial loss as reference if no previous MA exists
if "loss_ma_prev" not in locals():
loss_ma_prev = self.loss_history[0]
relative_ma_improvement = (loss_ma_prev - loss_ma) / (loss_ma_prev + 1e-10)
loss_ma_prev = loss_ma
else:
relative_ma_improvement = relative_improvement
loss_ma_prev = prev_loss # Initialize for future iterations
# Check combined convergence criteria
if it > 20 and relative_improvement < self.tol and relative_ma_improvement < self.tol:
self.logger.info(f"Converged at iteration {it}")
break
prev_loss = loss_value
# Display comprehensive progress information at regular intervals
if it % 25 == 0 or it < 5:
# Calculate performance metrics for monitoring optimization trajectory
if len(self.loss_history) > 1:
avg_last_5 = sum(self.loss_history[-min(5, len(self.loss_history)) :]) / min(
5, len(self.loss_history)
)
improvement_rate = (self.loss_history[0] - loss_value) / (it + 1) if it > 0 else 0
self.logger.info(
get_message(
"progress",
"iteration_progress",
current=it,
total=self.max_iter,
loss=loss_value,
avg_last_5=avg_last_5,
improvement_rate=improvement_rate,
)
)
else:
self.logger.info(
get_message(
"progress",
"iteration_progress",
current=it,
total=self.max_iter,
loss=loss_value,
)
)
# Provide in-depth diagnostics at major milestones
if it % 100 == 0 and it > 0:
# Analyze archetype characteristics
alpha_sparsity = jnp.mean(jnp.sum(params["alpha"] > 0.01, axis=1) / params["alpha"].shape[1])
gamma_sparsity = jnp.mean(jnp.sum(params["gamma"] > 0.01, axis=0) / params["gamma"].shape[0])
self.logger.debug(
f" - Alpha sparsity: {float(alpha_sparsity):.4f} | Gamma sparsity: {float(gamma_sparsity):.4f}"
)
self.logger.debug(f" - Learning rate: {float(schedule(it)):.8f}")
# Flag potential convergence issues
if jnp.max(params["alpha"]) > 0.99:
self.logger.warning(
" - Warning: Alpha contains near-one values, may indicate degenerate solution"
)
if jnp.max(params["gamma"]) > 0.99:
self.logger.warning(
" - Warning: Gamma contains near-one values, may indicate degenerate solution"
)
except Exception as e:
self.logger.error(f"Error at iteration {it}: {e!s}")
break
# Final projection of archetypes to ensure they're on the convex hull boundary
params["beta"] = self.project_row_archetypes(jnp.asarray(params["beta"]), X_jax)
params["theta"] = self.project_col_archetypes(jnp.asarray(params["theta"]), X_jax)
# Store final parameters
self.alpha = np.array(params["alpha"])
self.beta = np.array(params["beta"])
self.theta = np.array(params["theta"])
self.gamma = np.array(params["gamma"])
# Calculate biarchetypes (Z = beta·X·theta)
self.biarchetypes = np.array(np.matmul(np.matmul(self.beta, np.asanyarray(X_jax)), self.theta))
# For compatibility with parent class
self.archetypes = np.array(np.matmul(self.beta, X_jax)) # Row archetypes
self.weights = np.array(self.alpha) # Row weights
if len(self.loss_history) > 0:
self.logger.info(f"Final loss: {self.loss_history[-1]:.6f}")
else:
self.logger.warning("No valid loss was recorded")
return self
[docs]
def reconstruct(self, X: np.ndarray | None = None) -> np.ndarray:
"""Reconstruct data from biarchetypal representation.
This method provides the inverse operation of transform(), reconstructing
data points from their biarchetypal weights. This capability serves several
critical purposes:
1. Validation of model quality through reconstruction error assessment
2. Interpretation of what specific archetypes represent in data space
3. Generation of synthetic data by manipulating archetype weights
4. Noise reduction by reconstructing data through dominant archetypes
The method handles both original training data and new data points.
Args:
X: Optional data matrix to reconstruct. If None, uses the training data
Returns:
Reconstructed data matrix in the original feature space
"""
if X is not None:
# Transform new data and reconstruct
alpha, gamma = self.transform(X)
else:
# Use stored weights from training
if self.alpha is None or self.gamma is None:
raise ValueError("Model must be fitted before reconstruction")
alpha, gamma = self.alpha, self.gamma
if self.biarchetypes is None:
raise ValueError("Model must be fitted before reconstruction")
# Reconstruct using biarchetypes: X ≃ alpha·Z·gamma
reconstructed = np.matmul(np.matmul(alpha, self.biarchetypes), gamma)
# Inverse transform if normalization was applied
if self.X_mean is not None and self.X_std is not None:
reconstructed = reconstructed * self.X_std + self.X_mean
return np.asarray(reconstructed)
[docs]
def get_biarchetypes(self) -> np.ndarray:
"""Retrieve the core biarchetypes matrix.
The biarchetypes matrix (Z = β·X·θ) represents the heart of the model,
capturing the essential patterns at the intersection of row and column archetypes.
This matrix provides a compact representation of the data's underlying structure,
with each element representing a specific row-column archetype interaction.
Access to this matrix is essential for visualization, interpretation, and
advanced analysis of the identified patterns.
Returns:
Biarchetypes matrix (n_row_archetypes, n_col_archetypes)
"""
if self.biarchetypes is None:
raise ValueError("Model must be fitted before getting biarchetypes")
return self.biarchetypes
[docs]
def get_row_archetypes(self) -> np.ndarray:
"""Retrieve the row archetypes.
Row archetypes represent extreme patterns in observation space, describing
distinctive types of data points. These archetypes are essential for
understanding the primary modes of variation among observations and provide
the foundation for interpreting data point weights.
In the biarchetypal model, row archetypes are projections of the data matrix
via the beta coefficients (β·X).
Returns:
Row archetypes matrix (n_row_archetypes, n_features)
"""
if self.archetypes is None:
raise ValueError("Model must be fitted before getting row archetypes")
return self.archetypes
[docs]
def get_col_archetypes(self) -> np.ndarray:
"""Retrieve the column archetypes.
Column archetypes represent extreme patterns in feature space, describing
distinctive feature combinations or "feature types." This perspective is
unique to biarchetypal analysis and provides crucial insights about feature
relationships that would be missed in standard archetypal analysis.
These archetypes enable feature-level interpretations and can reveal
coordinated feature behaviors across different data contexts.
Returns:
Column archetypes matrix (n_col_archetypes, n_features)
"""
if self.theta is None or self.gamma is None:
raise ValueError("Model must be fitted before getting column archetypes")
# Modified: Changed the calculation method for column archetypes
# If the original data is not available, generate column archetypes from the shape of theta
if self.theta.shape[0] == self.theta.shape[1]:
# If theta is a square matrix, generate column archetypes similar to an identity matrix
return np.eye(self.theta.shape[0])
else:
# Position each column archetype along the feature space axes
col_archetypes = np.zeros((self.n_col_archetypes, self.theta.shape[0]))
for i in range(min(self.n_col_archetypes, self.theta.shape[0])):
col_archetypes[i, i] = 1.0
return col_archetypes
[docs]
def get_row_weights(self) -> np.ndarray:
"""Retrieve the row coefficients (alpha).
Row weights represent how each data point is composed as a mixture of
row archetypes. These weights are essential for:
1. Understanding which archetype patterns dominate each observation
2. Clustering similar observations based on their archetype compositions
3. Detecting anomalies as points with unusual archetype weights
4. Creating reduced-dimension visualizations based on archetype space
The weights are constrained to be non-negative and sum to 1 (simplex constraint),
making them directly interpretable as proportions.
Returns:
Row weight matrix (n_samples, n_row_archetypes)
"""
if self.alpha is None:
raise ValueError("Model must be fitted before getting row weights")
return self.alpha
[docs]
def get_col_weights(self) -> np.ndarray:
"""Retrieve the column coefficients (gamma).
Column weights represent how each feature is composed as a mixture of
column archetypes. These weights provide unique insights into:
1. Which feature patterns are expressed in each original feature
2. How features group together based on shared archetype influence
3. Feature importance through the lens of archetypal patterns
4. Potential redundancies in the feature space
This feature-space perspective is a distinguishing advantage of biarchetypal
analysis compared to standard archetypal methods.
Returns:
Column weight matrix (n_col_archetypes, n_features)
"""
if self.gamma is None:
raise ValueError("Model must be fitted before getting column weights")
return self.gamma
[docs]
def get_all_archetypes(self) -> tuple[np.ndarray, np.ndarray]:
"""Retrieve both row and column archetypes in a single call.
This convenience method provides access to both directions of archetypal
analysis simultaneously, facilitating comprehensive analysis and visualization
of the dual-perspective patterns. Accessing both archetypes together is
particularly valuable for cross-modal analysis examining relationships
between observation patterns and feature patterns.
Returns:
Tuple of (row_archetypes, column_archetypes) matrices
"""
return self.get_row_archetypes(), self.get_col_archetypes()
[docs]
def get_all_weights(self) -> tuple[np.ndarray, np.ndarray]:
"""Retrieve both row and column weights in a single call.
This convenience method provides access to all weight coefficients
simultaneously, enabling comprehensive analysis of how observations
and features relate to their respective archetypes. This unified view
is particularly valuable for understanding the full biarchetypal
decomposition and how information flows between the row and column spaces.
Returns:
Tuple of (row_weights, column_weights) matrices
"""
return self.get_row_weights(), self.get_col_weights()