"""Quantitative assessment tools for archetypal model validity and performance.
This module provides specialized metrics and visualizations for evaluating archetypal
analysis results. These tools address the critical gap between model fitting and
quality verification by offering:
1. Objective quantification of model performance across multiple dimensions
2. Statistical validation of archetype meaningfulness and separation
3. Specialized measures for interpretability and representational quality
4. Comparative frameworks for model selection and hyperparameter tuning
These capabilities are essential for ensuring model reliability, selecting optimal
configurations, and providing confidence in derived insights - particularly in
scientific, business intelligence, and decision support applications.
"""
import math # Import math module for factorial function
from typing import Any
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.spatial import ConvexHull, QhullError
from scipy.spatial.distance import cdist
from scipy.stats import entropy
from sklearn.metrics import davies_bouldin_score, silhouette_score
from ..models.base import ArchetypalAnalysis
[docs]
class ArchetypalAnalysisEvaluator:
"""Comprehensive evaluation suite for validating archetypal analysis quality.
This class provides specialized metrics and visualizations for assessing model
performance across multiple critical dimensions. Rather than relying on a single
metric, it offers a holistic evaluation approach that examines:
- Reconstruction fidelity and information preservation
- Archetype distinctiveness and interpretability
- Geometric properties of the archetype simplex
- Clustering quality and pattern discovery effectiveness
- Feature utilization patterns and importance distributions
This multi-faceted assessment is essential for model validation, hyperparameter
tuning, and ensuring that the archetypal representation provides meaningful
insights into the underlying data structure.
"""
[docs]
def __init__(self, model: ArchetypalAnalysis):
"""Initialize the evaluator with a fitted archetypal model.
Sets up the evaluation framework by extracting and caching key model
properties needed for efficient metric calculation. These properties
include archetype configurations, weight distributions, and dominant
archetype assignments that will be used across multiple evaluation
methods.
Args:
model: Fitted ArchetypalAnalysis model with discovered archetypes
and calculated weights
"""
self.model = model
if model.archetypes is None or model.weights is None:
raise ValueError("Model must be fitted before evaluation")
# Cache some frequently used values
self.n_archetypes = model.archetypes.shape[0]
self.n_features = model.archetypes.shape[1]
self.dominant_archetypes = np.argmax(model.weights, axis=1)
[docs]
def reconstruction_error(self, X: np.ndarray, metric: str = "frobenius") -> float:
"""Quantify how accurately the model reproduces the original data.
This fundamental metric measures the information loss between original
data and its archetypal reconstruction. The reconstruction error serves
several critical purposes:
- Validating that the model captures essential data patterns
- Comparing different archetype counts for optimal complexity
- Identifying potential overfitting or underfitting
- Providing an objective basis for model selection
The implementation offers multiple error metrics to accommodate different
sensitivity needs and statistical preferences.
Args:
X: Original data matrix to reconstruct
metric: Error calculation method:
'frobenius' - Matrix norm (sensitive to outliers)
'mae' - Mean absolute error (more robust)
'mse' - Mean squared error (standard in many contexts)
'relative' - Normalized by data magnitude (for comparison)
Returns:
Calculated reconstruction error (lower values indicate better fit)
"""
X_reconstructed = self.model.reconstruct()
if metric == "frobenius":
# Frobenius norm (default)
return float(np.linalg.norm(X - X_reconstructed, ord="fro"))
elif metric == "mae":
# Mean Absolute Error
return float(np.mean(np.abs(X - X_reconstructed)))
elif metric == "mse":
# Mean Squared Error
return float(np.mean((X - X_reconstructed) ** 2))
elif metric == "relative":
# Relative error
return float(np.linalg.norm(X - X_reconstructed, ord="fro") / np.linalg.norm(X, ord="fro"))
else:
raise ValueError(f"Unknown metric: {metric}. Use 'frobenius', 'mae', 'mse', or 'relative'.")
[docs]
def explained_variance(self, X: np.ndarray) -> float:
"""Measure the proportion of data variance captured by the archetypal model.
This intuitive metric expresses model quality as a percentage of total
data variation explained, similar to PCA's explained variance ratio.
This perspective offers several advantages:
- Provides an easily interpretable score between 0-1
- Enables direct comparison with other dimensionality reduction methods
- Helps determine if the chosen number of archetypes is sufficient
- Indicates whether important patterns have been missed
Higher values indicate that the archetypal representation captures
more of the information present in the original data.
Args:
X: Original data matrix for variance calculation
Returns:
Explained variance ratio (0-1, higher values indicate better fit)
"""
X_reconstructed = self.model.reconstruct(X)
# Calculate total variance
total_variance = np.var(X, axis=0).sum()
# Calculate residual variance
residual_variance = np.var(X - X_reconstructed, axis=0).sum()
# Calculate explained variance
explained_var = 1.0 - (residual_variance / total_variance)
return float(explained_var)
[docs]
def dominant_archetype_purity(self) -> dict[str, Any]:
"""Analyze how distinctly samples associate with their primary archetypes.
This metric quantifies how uniquely each sample is represented by a single
archetype rather than being a mixture of many. High purity indicates that:
- Archetypes represent distinct, well-separated patterns in the data
- Samples can be meaningfully assigned to specific archetypes
- The model has discovered genuine structure rather than arbitrary positions
- Classification and interpretation of new samples will be more reliable
Low purity suggests overlapping archetypes or that more archetypes may
be needed to represent the data's inherent structure.
Returns:
Dictionary with purity metrics including:
- Per-archetype purity scores
- Overall dataset purity
- Purity variation statistics
- Raw maximum weight values
"""
if self.model.weights is None:
raise ValueError("Model must be fitted before evaluating purity")
# Get weights for each sample
weights: np.ndarray = self.model.weights
# Get maximum weight for each sample
max_weights = np.max(weights, axis=1)
# Calculate average purity for each archetype
archetype_purity = {}
for i in range(self.n_archetypes):
archetype_mask = self.dominant_archetypes == i
if np.sum(archetype_mask) > 0: # Check if archetype has any assigned samples
avg_purity = np.mean(max_weights[archetype_mask])
archetype_purity[f"Archetype_{i}"] = avg_purity
# Calculate overall purity metrics
overall_purity = np.mean(max_weights)
purity_std = np.std(max_weights)
return {
"archetype_purity": archetype_purity,
"overall_purity": overall_purity,
"purity_std": purity_std,
"max_weights": max_weights,
}
[docs]
def archetype_separation(self) -> dict[str, float]:
"""Measure the geometric distinctiveness between discovered archetypes.
This metric quantifies how well-separated archetypes are in feature space,
which is crucial for interpretability and meaningful pattern detection.
Well-separated archetypes indicate:
- Clear differentiation between discovered patterns
- Minimal redundancy in the archetypal representation
- Stronger interpretability of what each archetype represents
- More robust and stable optimization results
Poor separation suggests potential issues like local minima traps,
excessive archetypes, or inherent pattern similarity in the data.
Returns:
Dictionary with separation metrics including:
- Minimum distance between any two archetypes
- Maximum pairwise distance in the set
- Average inter-archetype distance
- Ratio of minimum to maximum distance (uniformity measure)
"""
# Calculate all pairwise distances between archetypes
archetype_distances = cdist(self.model.archetypes, self.model.archetypes)
# Fill diagonal with NaN to ignore self-distances
np.fill_diagonal(archetype_distances, np.nan)
# Calculate metrics
min_distance = np.nanmin(archetype_distances)
max_distance = np.nanmax(archetype_distances)
mean_distance = np.nanmean(archetype_distances)
return {
"min_distance": min_distance,
"max_distance": max_distance,
"mean_distance": mean_distance,
"distance_ratio": min_distance / max_distance if max_distance > 0 else 0,
}
[docs]
def clustering_metrics(self, X: np.ndarray) -> dict[str, float]:
"""Evaluate the archetypes' effectiveness as cluster centroids.
This analysis bridges archetypal analysis with clustering by treating
dominant archetype assignments as cluster memberships. This perspective
provides critical insights into:
- How well archetypes identify natural groupings in the data
- The coherence of samples dominated by the same archetype
- Separation between different archetype-defined groups
- The comparative quality versus traditional clustering techniques
These metrics help validate that archetypes not only reconstruct the
data accurately but also discover meaningful structural patterns.
Args:
X: Original data matrix for clustering evaluation
Returns:
Dictionary with clustering quality metrics:
- Silhouette score (higher values indicate better-defined clusters)
- Davies-Bouldin index (lower values indicate better separation)
"""
# Need at least 2 archetypes and more samples than archetypes
if self.n_archetypes < 2 or X.shape[0] <= self.n_archetypes:
return {"silhouette": np.nan, "davies_bouldin": np.nan}
try:
# Silhouette score (higher is better)
silhouette = silhouette_score(X, self.dominant_archetypes)
# Davies-Bouldin index (lower is better)
davies_bouldin = davies_bouldin_score(X, self.dominant_archetypes)
return {"silhouette": silhouette, "davies_bouldin": davies_bouldin}
except Exception as e:
print(f"Could not compute clustering metrics: {e!s}")
return {"silhouette": np.nan, "davies_bouldin": np.nan}
[docs]
def archetype_feature_importance(self) -> pd.DataFrame:
"""Identify which features define and distinguish each archetype.
This analysis reveals the characteristic features that make each archetype
unique, translating abstract archetypes into interpretable patterns.
Understanding feature importance enables:
- Interpretation of what each archetype represents in domain terms
- Identification of defining characteristics for each extreme pattern
- Feature selection based on archetypal relevance
- Targeted analysis of specific variables driving pattern differences
The resulting feature importance profiles are essential for deriving
actionable insights and explaining archetypal patterns to stakeholders.
Returns:
DataFrame with normalized feature importance scores for each archetype,
where higher absolute values indicate more distinctive usage
"""
# Get archetypes
archetypes = self.model.archetypes
if archetypes is None:
raise ValueError("Model archetypes must not be None")
# Calculate feature-wise z-scores for each archetype
feature_means = np.mean(archetypes, axis=0)
feature_stds = np.std(archetypes, axis=0)
# Avoid division by zero
feature_stds = np.where(feature_stds < 1e-10, 1.0, feature_stds)
# Calculate z-scores
feature_importance = np.abs((archetypes - feature_means) / feature_stds)
# Create DataFrame
archetype_names = [f"Archetype_{i}" for i in range(self.n_archetypes)]
feature_names = [f"Feature_{i}" for i in range(self.n_features)]
return pd.DataFrame(feature_importance, index=archetype_names, columns=feature_names)
[docs]
def weight_diversity(self) -> dict[str, float]:
"""
Measure how diverse the weight distributions are across samples.
Returns:
Dictionary with diversity metrics
"""
weights = self.model.weights
if weights is None:
raise ValueError("Model weights must not be None")
# Calculate entropy for each sample's weight distribution
sample_entropy = np.array([entropy(w) for w in weights])
# Theoretical maximum entropy for uniform distribution
max_entropy = np.log(self.n_archetypes)
# Normalize entropy (0-1 scale)
normalized_entropy = sample_entropy / max_entropy
return {
"mean_entropy": np.mean(sample_entropy),
"mean_normalized_entropy": np.mean(normalized_entropy),
"entropy_std": np.std(sample_entropy),
"min_entropy": np.min(sample_entropy),
"max_entropy": np.max(sample_entropy),
}
[docs]
def convex_hull_metrics(self) -> dict[str, Any]:
"""
Calculate metrics related to the convex hull formed by the archetypes.
This method evaluates whether the archetypes form a non-degenerate convex hull
by calculating its volume/area and comparing it to the data's convex hull.
Returns:
Dictionary with convex hull metrics including:
- volume/area of the convex hull
- ratio compared to data hull volume/area
- dimensionality of the hull
"""
archetypes = self.model.archetypes
if archetypes is None:
raise ValueError("Model must be fitted before evaluating convex hull")
# Ensure we have enough archetypes to form a convex hull
n_archetypes, n_features = archetypes.shape
min_points_needed = min(n_features + 1, n_archetypes)
hull_metrics: dict[str, Any] = {
"volume": 0.0,
"volume_ratio": 0.0,
"dimensionality": 0,
"is_degenerate": True,
}
# Check if we have enough points to form a hull
if n_archetypes < min_points_needed:
hull_metrics["error"] = (
f"Not enough archetypes ({n_archetypes}) to form a convex hull in {n_features}D space"
)
return hull_metrics
try:
# Calculate convex hull of archetypes
archetype_hull = ConvexHull(archetypes)
hull_metrics["volume"] = archetype_hull.volume
hull_metrics["dimensionality"] = archetype_hull.ndim
hull_metrics["is_degenerate"] = False
# If we have access to the original data, compare to data hull
if hasattr(self.model, "X") and self.model.X is not None:
try:
data_hull = ConvexHull(self.model.X)
hull_metrics["data_volume"] = data_hull.volume
hull_metrics["volume_ratio"] = archetype_hull.volume / data_hull.volume
except QhullError:
# Data might not form a valid convex hull
hull_metrics["data_volume"] = None
hull_metrics["volume_ratio"] = None
except QhullError as e:
# Handle the case where archetypes form a degenerate convex hull
hull_metrics["error"] = f"Degenerate convex hull: {e!s}"
hull_metrics["is_degenerate"] = True
# If hull calculation failed, calculate the n-dimensional simplex volume using determinant
if n_archetypes >= 2: # Need at least 2 points for any meaningful volume
try:
# Center the archetypes
centered = archetypes - np.mean(archetypes, axis=0)
# For 2D case (area)
if n_features == 2 and n_archetypes >= 3:
# Calculate area using Shoelace formula
x = archetypes[:, 0]
y = archetypes[:, 1]
area = 0.5 * np.abs(np.sum(x * np.roll(y, 1) - np.roll(x, 1) * y))
hull_metrics["volume"] = area
hull_metrics["dimensionality"] = 2
# For higher dimensions, estimate volume using matrix determinant
elif n_archetypes >= n_features + 1:
# Select n_features archetypes to form a basis
vectors = centered[1 : n_features + 1] - centered[0]
# Calculate volume of parallelotope
volume = np.abs(np.linalg.det(vectors)) / math.factorial(n_features)
hull_metrics["volume"] = volume
hull_metrics["dimensionality"] = n_features
if hull_metrics["volume"] > 1e-10:
hull_metrics["is_degenerate"] = False
except Exception as calc_err:
# Explicitly convert to string to avoid typing issues
error_message = str(calc_err)
# Use a placeholder numeric value for error cases
hull_metrics["volume"] = 0.0
hull_metrics["calculation_error"] = error_message
return hull_metrics
[docs]
def plot_convex_hull(self, feature_indices: list[int] | None = None, figsize: tuple[int, int] = (10, 8)) -> None:
"""
Plot the convex hull formed by archetypes in 2D or 3D.
Args:
feature_indices: Indices of features to use for visualization (2 or 3 features)
figsize: Size of the figure
"""
archetypes = self.model.archetypes
if archetypes is None:
raise ValueError("Model must be fitted before plotting convex hull")
if feature_indices is None:
feature_indices = [0, 1, 2] if archetypes.shape[1] >= 3 else [0, 1]
if len(feature_indices) not in [2, 3]:
raise ValueError("feature_indices must contain 2 or 3 feature indices for 2D or 3D visualization")
selected_archetypes = archetypes[:, feature_indices]
plt.figure(figsize=figsize)
if len(feature_indices) == 2:
plt.scatter(
selected_archetypes[:, 0],
selected_archetypes[:, 1],
s=100,
c="r",
marker="o",
label="Archetypes",
)
# Try to plot the convex hull
try:
hull = ConvexHull(selected_archetypes)
for simplex in hull.simplices:
plt.plot(selected_archetypes[simplex, 0], selected_archetypes[simplex, 1], "k-")
# Add area information
area = float(hull.volume) # In 2D, volume is area
plt.title(f"Convex Hull of Archetypes (Area: {area:.4f})")
except QhullError:
plt.title("Archetypes (Degenerate Convex Hull)")
# Plot original data if available
if hasattr(self.model, "X") and self.model.X is not None:
data = self.model.X[:, feature_indices]
plt.scatter(data[:, 0], data[:, 1], s=10, alpha=0.5, label="Data")
plt.xlabel(f"Feature {feature_indices[0]}")
plt.ylabel(f"Feature {feature_indices[1]}")
# 3D plot
else:
ax = plt.figure().add_subplot(111, projection="3d")
ax.scatter(
selected_archetypes[:, 0],
selected_archetypes[:, 1],
selected_archetypes[:, 2],
# s=100,
color="r",
marker="o",
label="Archetypes",
)
# Try to plot the convex hull
try:
hull = ConvexHull(selected_archetypes)
for simplex in hull.simplices:
ax.plot(
selected_archetypes[simplex, 0],
selected_archetypes[simplex, 1],
selected_archetypes[simplex, 2],
"k-",
)
# Add volume information
volume = float(hull.volume)
ax.set_title(f"Convex Hull of Archetypes (Volume: {volume:.4f})")
except QhullError:
ax.set_title("Archetypes (Degenerate Convex Hull)")
# Plot original data if available
if hasattr(self.model, "X") and self.model.X is not None:
data = self.model.X[:, feature_indices]
ax.scatter(
data[:, 0],
data[:, 1],
data[:, 2],
# s=10,
alpha=0.3,
color="blue",
label="Data",
)
ax.set_xlabel(f"Feature {feature_indices[0]}")
ax.set_ylabel(f"Feature {feature_indices[1]}")
if hasattr(ax, "set_zlabel"):
ax.set_zlabel(f"Feature {feature_indices[2]}")
plt.legend()
plt.tight_layout()
plt.show()
[docs]
def comprehensive_evaluation(self, X: np.ndarray) -> dict[str, Any]:
"""
Run all evaluation metrics and return comprehensive results.
Args:
X: Original data matrix
Returns:
Dictionary with all evaluation metrics
"""
results = {
"reconstruction": {
"frobenius": self.reconstruction_error(X, "frobenius"),
"mae": self.reconstruction_error(X, "mae"),
"mse": self.reconstruction_error(X, "mse"),
"relative": self.reconstruction_error(X, "relative"),
},
"explained_variance": self.explained_variance(X),
"purity": self.dominant_archetype_purity(),
"separation": self.archetype_separation(),
"clustering": self.clustering_metrics(X),
"diversity": self.weight_diversity(),
"convex_hull": self.convex_hull_metrics(),
}
return results
[docs]
def print_evaluation_report(self, X: np.ndarray) -> None:
"""
Print a comprehensive evaluation report.
Args:
X: Original data matrix
"""
results = self.comprehensive_evaluation(X)
print("\n" + "=" * 50)
print(f"ARCHETYPAL ANALYSIS EVALUATION ({self.n_archetypes} archetypes)")
print("=" * 50)
print("\n1. RECONSTRUCTION METRICS:")
print(f" - Reconstruction Error: {results['reconstruction']['relative']:.4f}")
print(f" - Explained Variance: {results['explained_variance']:.4f}")
print("\n2. ARCHETYPE SEPARATION:")
print(f" - Minimum Distance: {results['separation']['min_distance']:.4f}")
print(f" - Maximum Distance: {results['separation']['max_distance']:.4f}")
print(f" - Mean Distance: {results['separation']['mean_distance']:.4f}")
print(f" - Distance Ratio (min/max): {results['separation']['distance_ratio']:.4f}")
print("\n3. DOMINANT ARCHETYPE PURITY:")
print(f" - Overall Purity: {results['purity']['overall_purity']:.4f}")
print(f" - Purity Std Dev: {results['purity']['purity_std']:.4f}")
print(" - Per-Archetype Purity:")
for archetype, purity in results["purity"]["archetype_purity"].items():
print(f" - {archetype}: {purity:.4f}")
print("\n4. CLUSTERING METRICS:")
if not np.isnan(results["clustering"]["silhouette"]):
print(f" - Silhouette Score: {results['clustering']['silhouette']:.4f}")
print(f" - Davies-Bouldin Index: {results['clustering']['davies_bouldin']:.4f}")
else:
print(" - Clustering metrics not available (insufficient data)")
print("\n5. WEIGHT DIVERSITY:")
print(f" - Mean Entropy: {results['diversity']['mean_entropy']:.4f}")
print(f" - Min Entropy: {results['diversity']['min_entropy']:.4f}")
print(f" - Max Entropy: {results['diversity']['max_entropy']:.4f}")
print("\n6. CONVEX HULL METRICS:")
hull_metrics = results["convex_hull"]
print(f" - Volume/Area: {hull_metrics['volume']:.6f}")
if hull_metrics.get("volume_ratio") is not None:
print(f" - Volume Ratio (vs Data): {hull_metrics['volume_ratio']:.4f}")
print(f" - Dimensionality: {hull_metrics['dimensionality']}")
print(f" - Is Degenerate: {hull_metrics['is_degenerate']}")
if "error" in hull_metrics:
print(f" - Error: {hull_metrics['error']}")
print("\n" + "=" * 50)
# Visualization methods for high-dimensional data
[docs]
def plot_feature_importance_heatmap(self, feature_names: list[str] | None = None) -> None:
"""
Plot heatmap of feature importance across archetypes.
Args:
feature_names: Optional list of feature names
"""
importance_df = self.archetype_feature_importance()
# Rename columns if feature names provided
if feature_names is not None and len(feature_names) == self.n_features:
importance_df = pd.DataFrame(importance_df.values, index=importance_df.index, columns=feature_names)
plt.figure(figsize=(12, 8))
sns.heatmap(importance_df, cmap="viridis", annot=True)
plt.title("Feature Importance Across Archetypes")
plt.xlabel("Features")
plt.ylabel("Archetypes")
plt.tight_layout()
plt.show()
[docs]
def plot_archetype_feature_comparison(self, top_n: int = 5, feature_names: list[str] | None = None) -> None:
"""
Plot radar chart or bar chart comparing top N most important features for each archetype.
Args:
top_n: Number of top features to display
feature_names: Optional list of feature names
"""
importance_df = self.archetype_feature_importance()
# Rename columns if feature names provided
if feature_names is not None and len(feature_names) == self.n_features:
importance_df = pd.DataFrame(importance_df.values, index=importance_df.index, columns=feature_names)
# For each archetype, get the top N most important features
plt.figure(figsize=(15, 4 * ((self.n_archetypes + 1) // 2)))
for i in range(self.n_archetypes):
# Sort features by importance for this archetype
archetype_importance = importance_df.iloc[i].sort_values(ascending=False)
top_features = archetype_importance.head(top_n)
plt.subplot(((self.n_archetypes + 1) // 2), 2, i + 1)
bars = plt.bar(
np.arange(len(top_features)),
top_features.values.astype(float),
tick_label=top_features.index,
color="skyblue",
)
# Add values on top of bars
for bar in bars:
height = bar.get_height()
plt.text(
bar.get_x() + bar.get_width() / 2.0,
height + 0.05,
f"{height:.2f}",
ha="center",
va="bottom",
rotation=0,
)
plt.title(f"Archetype {i}: Top {top_n} Features")
plt.ylim(0, max(top_features.values) * 1.2) # Add headroom for text
plt.xticks(rotation=45, ha="right")
plt.tight_layout()
plt.tight_layout()
plt.show()
[docs]
def plot_weight_distributions(self, bins: int = 20) -> None:
"""
Plot histograms of weight distributions for each archetype.
Args:
bins: Number of histogram bins
"""
weights = self.model.weights
if weights is None:
raise ValueError("Model weights must not be None")
plt.figure(figsize=(15, 4 * ((self.n_archetypes + 1) // 2)))
for i in range(self.n_archetypes):
plt.subplot(((self.n_archetypes + 1) // 2), 2, i + 1)
# Get weights for this archetype
archetype_weights = weights[:, i]
# Plot histogram
plt.hist(archetype_weights, bins=bins, alpha=0.7, color="skyblue")
plt.title(f"Archetype {i} Weight Distribution")
plt.xlabel("Weight")
plt.ylabel("Number of Samples")
# Add statistics
plt.axvline(
np.mean(archetype_weights),
color="r",
linestyle="--",
label=f"Mean: {np.mean(archetype_weights):.3f}",
)
plt.axvline(
np.median(archetype_weights),
color="g",
linestyle="-",
label=f"Median: {np.median(archetype_weights):.3f}",
)
plt.legend()
plt.tight_layout()
plt.show()
[docs]
def plot_purity_distribution(self) -> None:
"""Plot the distribution of dominant archetype weights (purity)."""
purity_data = self.dominant_archetype_purity()
if "max_weights" not in purity_data:
raise ValueError("Max weights data is missing")
max_weights = purity_data["max_weights"]
if max_weights is None:
raise ValueError("Max weights is None")
plt.figure(figsize=(10, 6))
# Plot histogram
plt.hist(max_weights, bins=20, alpha=0.7, color="skyblue")
plt.title("Distribution of Dominant Archetype Weights (Purity)")
plt.xlabel("Maximum Weight")
plt.ylabel("Number of Samples")
# Add statistics
plt.axvline(
np.mean(max_weights),
color="r",
linestyle="--",
label=f"Mean: {np.mean(max_weights):.3f}",
)
plt.axvline(
np.median(max_weights),
color="g",
linestyle="-",
label=f"Median: {np.median(max_weights):.3f}",
)
# Theoretical threshold for uniform weights
uniform_weight = 1.0 / self.n_archetypes
plt.axvline(
uniform_weight,
color="k",
linestyle=":",
label=f"Uniform: {uniform_weight:.3f}",
)
plt.legend()
plt.tight_layout()
plt.show()
[docs]
def plot_distance_matrix(self) -> None:
"""Plot distance matrix between archetypes."""
# Calculate pairwise distances
distances = cdist(self.model.archetypes, self.model.archetypes)
plt.figure(figsize=(10, 8))
# Create heatmap
sns.heatmap(
distances,
annot=True,
cmap="viridis",
xticklabels=[f"A{i}" for i in range(self.n_archetypes)],
yticklabels=[f"A{i}" for i in range(self.n_archetypes)],
)
plt.title("Distance Matrix Between Archetypes")
plt.tight_layout()
plt.show()
[docs]
def plot_entropy_vs_reconstruction(self, X: np.ndarray, n_samples: int = 1000) -> None:
"""
Plot relationship between sample entropy and reconstruction error.
Args:
X: Original data matrix
n_samples: Number of samples to plot (random subset)
"""
weights = self.model.weights
X_reconstructed = self.model.reconstruct()
if weights is None:
raise ValueError("Model weights must not be None")
# Calculate point-wise reconstruction error
point_errors = np.sqrt(np.sum((X - X_reconstructed) ** 2, axis=1))
# Calculate entropy for each point
entropies = np.array([entropy(w) for w in weights])
# Normalize to maximum possible entropy
max_entropy = np.log(self.n_archetypes)
normalized_entropies = entropies / max_entropy
# Select subset if needed
if n_samples < len(entropies) and n_samples > 0:
indices = np.random.choice(len(entropies), size=n_samples, replace=False)
point_errors = point_errors[indices]
normalized_entropies = normalized_entropies[indices]
dominant_archetypes = self.dominant_archetypes[indices]
else:
dominant_archetypes = self.dominant_archetypes
plt.figure(figsize=(10, 8))
# Scatter plot colored by dominant archetype
scatter = plt.scatter(
normalized_entropies,
point_errors,
c=dominant_archetypes,
cmap="viridis",
alpha=0.6,
s=30,
)
# Add color legend
legend = plt.legend(*scatter.legend_elements(), title="Dominant Archetype")
plt.gca().add_artist(legend)
# Add correlation coefficient
corr = np.corrcoef(normalized_entropies, point_errors)[0, 1]
plt.text(
0.05,
0.95,
f"Correlation: {corr:.3f}",
transform=plt.gca().transAxes,
bbox={"facecolor": "white", "alpha": 0.8},
)
plt.xlabel("Normalized Entropy (Diversity)")
plt.ylabel("Reconstruction Error")
plt.title("Relationship Between Weight Diversity and Reconstruction Error")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
[docs]
class BiarchetypalAnalysisEvaluator:
"""
Evaluator for Biarchetypal Analysis results.
Provides metrics and visualizations to assess model quality for biarchetypal models,
which use two sets of archetypes to represent data.
"""
[docs]
def __init__(self, model):
    """
    Set up the evaluator for a fitted biarchetypal model.

    Args:
        model: Fitted BiarchetypalAnalysis model

    Raises:
        TypeError: If ``model`` is not a BiarchetypalAnalysis instance.
        ValueError: If any of ``alpha``, ``beta``, ``theta`` or ``gamma`` is still None.
    """
    # Imported here rather than at module level; presumably to avoid an import
    # cycle with the models package (TODO confirm).
    from ..models.biarchetypes import BiarchetypalAnalysis

    if not isinstance(model, BiarchetypalAnalysis):
        raise TypeError("Model must be a BiarchetypalAnalysis instance")
    self.model = model

    # All four factor matrices must exist. Checked lazily, in the same order.
    required = ("alpha", "beta", "theta", "gamma")
    if any(getattr(model, name) is None for name in required):
        raise ValueError("Model must be fitted before evaluation")

    # Archetype counts for the two archetype sets
    self.n_archetypes_first = model.n_row_archetypes
    self.n_archetypes_second = model.n_col_archetypes
    # The first axis of theta runs over features
    self.n_features = model.theta.shape[0]

    # Dominant archetype per row of alpha (argmax over axis 1) and per column
    # of gamma (argmax over axis 0)
    self.dominant_archetypes_first = np.argmax(model.alpha, axis=1)
    self.dominant_archetypes_second = np.argmax(model.gamma, axis=0)
[docs]
def reconstruction_error(self, X: np.ndarray, metric: str = "frobenius") -> float:
    """
    Calculate the reconstruction error of the model on ``X``.

    Args:
        X: Data matrix
        metric: Error metric to use:
            'frobenius' - Frobenius norm of the residual divided by sqrt(X.shape[0]).
                This is a per-row-normalised norm, not the raw Frobenius norm.
            'mse' - Mean squared error
            'mae' - Mean absolute error
            'relative' - Residual Frobenius norm divided by the Frobenius norm of X

    Returns:
        Reconstruction error value (lower is better)

    Raises:
        ValueError: If ``metric`` is not one of the names above. The check runs
            before the reconstruction, so a typo fails fast.
    """
    if metric not in ("frobenius", "mse", "mae", "relative"):
        raise ValueError(f"Unknown metric: {metric}. Use 'frobenius', 'mae', 'mse', or 'relative'.")

    residual = X - self.model.reconstruct(X)
    if metric == "frobenius":
        return float(np.linalg.norm(residual, ord="fro") / np.sqrt(X.shape[0]))
    if metric == "mse":
        return float(np.mean(residual**2))
    if metric == "mae":
        return float(np.mean(np.abs(residual)))
    # metric == "relative"
    return float(np.linalg.norm(residual, ord="fro") / np.linalg.norm(X, ord="fro"))
[docs]
def explained_variance(self, X: np.ndarray) -> float:
    """
    Calculate the explained variance of the model.

    Computed as ``1 - sum(var(X - X_hat)) / sum(var(X))`` using per-feature
    variances.

    Args:
        X: Data matrix

    Returns:
        Explained variance (at most 1; higher is better). For constant data (zero
        total variance) it returns 1.0 when the reconstruction is exact and 0.0
        otherwise, instead of dividing by zero.
    """
    X_reconstructed = self.model.reconstruct(X)
    total_variance = float(np.var(X, axis=0).sum())
    residual_variance = float(np.var(X - X_reconstructed, axis=0).sum())
    if total_variance == 0.0:
        # Same finite convention as scikit-learn's explained_variance_score
        # with force_finite=True.
        return 1.0 if residual_variance == 0.0 else 0.0
    return 1.0 - residual_variance / total_variance
[docs]
def archetype_separation(self):
    """Calculate separation metrics between archetypes within each archetype set.

    Each archetype is represented by its loading profile:

    - First set: a column of ``alpha``. Argmax over axis 1 of alpha yields an
      archetype index, so columns index archetypes; alpha is transposed to get
      one row per archetype.
    - Second set: a row of ``gamma``. Argmax over axis 0 of gamma yields an
      archetype index, so rows index archetypes.

    Using rows of ``alpha`` directly would compare rows (one per item being
    weighted) instead of archetypes. It would also allocate a matrix quadratic in
    alpha's row count.

    Returns:
        Dictionary with ``mean/min/max_distance_first`` and
        ``mean/min/max_distance_second`` (0.0 when a set has fewer than two
        archetypes), plus ``mean/min/max_cross_distance``, which are always 0.0.
    """

    def _pair_stats(profiles: np.ndarray, suffix: str) -> dict[str, float]:
        # Each unordered pair exactly once: upper triangle without the diagonal.
        rows, cols = np.triu_indices(profiles.shape[0], k=1)
        if rows.size == 0:
            return {
                f"mean_distance_{suffix}": 0.0,
                f"min_distance_{suffix}": 0.0,
                f"max_distance_{suffix}": 0.0,
            }
        pair_distances = cdist(profiles, profiles)[rows, cols]
        return {
            f"mean_distance_{suffix}": float(np.mean(pair_distances)),
            f"min_distance_{suffix}": float(np.min(pair_distances)),
            f"max_distance_{suffix}": float(np.max(pair_distances)),
        }

    metrics_first = _pair_stats(np.asarray(self.model.alpha).T, "first")
    metrics_second = _pair_stats(np.asarray(self.model.gamma), "second")

    # NOTE(review): cross-set distances are not computed. The two sets' profiles
    # need not have the same length, so these remain 0.0 placeholders.
    metrics_cross = {
        "mean_cross_distance": 0.0,
        "min_cross_distance": 0.0,
        "max_cross_distance": 0.0,
    }
    return {**metrics_first, **metrics_second, **metrics_cross}
[docs]
def dominant_archetype_purity(self) -> dict:
    """
    Calculate purity metrics for dominant archetypes.

    For each set, the entries of ``self.dominant_archetypes_first`` /
    ``self.dominant_archetypes_second`` are counted per archetype index with
    ``np.bincount`` (padded to ``self.n_archetypes_first`` /
    ``self.n_archetypes_second``). The counts are then normalised to
    proportions.

    Returns:
        Dictionary of purity metrics:
        - ``archetype_purity_first`` / ``archetype_purity_second``: arrays of
          per-archetype proportions. All zeros when there are no assignments.
        - ``overall_purity_first`` / ``overall_purity_second``: the largest
          proportion, or 0.0 for an empty array.
        - ``purity_std_first`` / ``purity_std_second``: standard deviation of
          the proportions, or 0.0 for an empty array.
    """

    def _proportions(assignments, n_archetypes):
        """Return the share of ``assignments`` falling on each archetype index."""
        counts = np.bincount(assignments, minlength=n_archetypes)
        total = counts.sum()
        # With no assignments, counts / total would be 0/0 -> NaN; report zero shares instead.
        if total == 0:
            return np.zeros(counts.shape, dtype=float)
        return counts / total

    purity_first = _proportions(self.dominant_archetypes_first, self.n_archetypes_first)
    purity_second = _proportions(self.dominant_archetypes_second, self.n_archetypes_second)

    return {
        "archetype_purity_first": purity_first,
        "archetype_purity_second": purity_second,
        "overall_purity_first": float(np.max(purity_first)) if purity_first.size > 0 else 0.0,
        "overall_purity_second": float(np.max(purity_second)) if purity_second.size > 0 else 0.0,
        "purity_std_first": float(np.std(purity_first)) if purity_first.size > 0 else 0.0,
        "purity_std_second": float(np.std(purity_second)) if purity_second.size > 0 else 0.0,
    }
[docs]
def weight_diversity(self) -> dict:
    """
    Summarise how evenly the model spreads its archetype weights.

    Base-2 Shannon entropies are taken over each row of ``self.model.alpha``
    and over each column of ``self.model.gamma``. Each entropy is also divided
    by log2 of the number of entries it sums over. When that bound is not
    positive (a single entry), the normalised entropies are reported as zero.

    Returns:
        Dictionary with, for each suffix ``first`` (alpha) and ``second``
        (gamma): ``mean_entropy_*``, ``entropy_std_*``, ``max_entropy_*`` and
        ``mean_normalized_entropy_*``.

    Raises:
        ValueError: If either weight matrix is None (model not fitted).
    """
    alpha, gamma = self.model.alpha, self.model.gamma
    if alpha is None or gamma is None:
        raise ValueError("Model must be fitted before calculating weight diversity")

    summary = {}
    # alpha is summarised row by row (axis 1), gamma column by column (axis 0).
    for suffix, weights, axis in (("first", alpha, 1), ("second", gamma, 0)):
        # The 1e-10 offset keeps log2 finite where a weight is exactly zero.
        ent = -np.sum(weights * np.log2(weights + 1e-10), axis=axis)
        bound = np.log2(weights.shape[axis])
        # Guard against dividing by log2(1) == 0.
        normalised = ent / bound if bound > 0 else np.zeros_like(ent)
        summary[f"mean_entropy_{suffix}"] = np.mean(ent)
        summary[f"entropy_std_{suffix}"] = np.std(ent)
        summary[f"max_entropy_{suffix}"] = np.max(ent)
        summary[f"mean_normalized_entropy_{suffix}"] = np.mean(normalised)
    return summary
[docs]
def comprehensive_evaluation(self, X: np.ndarray) -> dict:
    """
    Run the evaluator's metric methods and gather their results.

    Args:
        X: Data matrix handed to the reconstruction-based metrics.

    Returns:
        Nested dictionary with the sections ``reconstruction``,
        ``separation``, ``purity``, ``diversity``, ``clustering`` and
        ``convex_hull``. The last two hold fixed placeholder values here:
        NaN clustering scores and a zero-volume, degenerate hull.
    """
    # Reconstruction error under each supported metric, then explained variance.
    reconstruction = {
        metric: self.reconstruction_error(X, metric=metric)
        for metric in ("frobenius", "mse", "mae", "relative")
    }
    reconstruction["explained_variance"] = self.explained_variance(X)

    # Dict-literal values are evaluated in order, so the metric methods run in this sequence.
    return {
        "reconstruction": reconstruction,
        "separation": self.archetype_separation(),
        "purity": self.dominant_archetype_purity(),
        "diversity": self.weight_diversity(),
        # Clustering and convex-hull metrics are not computed here; placeholders only.
        "clustering": {"silhouette": np.nan, "davies_bouldin": np.nan},
        "convex_hull": {
            "volume": 0.0,
            "dimensionality": 0,
            "is_degenerate": True,
            "volume_ratio": 0.0,
        },
    }
[docs]
def print_evaluation_report(self, X: np.ndarray) -> None:
    """
    Print a comprehensive evaluation report.

    Runs ``self.comprehensive_evaluation(X)`` and prints its reconstruction,
    separation, purity, clustering, weight-diversity and convex-hull sections
    to stdout.

    Args:
        X: Original data matrix
    """
    results = self.comprehensive_evaluation(X)

    def _print_per_archetype(purity_values):
        """Print one ``label: value`` line per archetype purity."""
        # Purity values may be a mapping or an array; ndarrays have no
        # ``.items()``, so arrays are labelled by position instead.
        if hasattr(purity_values, "items"):
            pairs = purity_values.items()
        else:
            pairs = enumerate(purity_values)
        for archetype, purity in pairs:
            print(f" - {archetype}: {purity:.4f}")

    print("\n" + "=" * 50)
    print(
        f"ARCHETYPAL ANALYSIS EVALUATION ({self.n_archetypes_first} archetypes, {self.n_archetypes_second} archetypes)"
    )
    print("=" * 50)
    print("\n1. RECONSTRUCTION METRICS:")
    print(f" - Reconstruction Error: {results['reconstruction']['relative']:.4f}")
    print(f" - Explained Variance: {results['reconstruction']['explained_variance']:.4f}")
    print("\n2. ARCHETYPE SEPARATION:")
    print(f" - Minimum Distance (First Set): {results['separation']['min_distance_first']:.4f}")
    print(f" - Maximum Distance (First Set): {results['separation']['max_distance_first']:.4f}")
    print(f" - Mean Distance (First Set): {results['separation']['mean_distance_first']:.4f}")
    print(f" - Minimum Distance (Second Set): {results['separation']['min_distance_second']:.4f}")
    print(f" - Maximum Distance (Second Set): {results['separation']['max_distance_second']:.4f}")
    print(f" - Mean Distance (Second Set): {results['separation']['mean_distance_second']:.4f}")
    print("\n3. DOMINANT ARCHETYPE PURITY:")
    print(f" - Overall Purity (First Set): {results['purity']['overall_purity_first']:.4f}")
    print(f" - Overall Purity (Second Set): {results['purity']['overall_purity_second']:.4f}")
    print(" - Per-Archetype Purity (First Set):")
    _print_per_archetype(results["purity"]["archetype_purity_first"])
    print(" - Per-Archetype Purity (Second Set):")
    _print_per_archetype(results["purity"]["archetype_purity_second"])
    print("\n4. CLUSTERING METRICS:")
    if not np.isnan(results["clustering"]["silhouette"]):
        print(f" - Silhouette Score: {results['clustering']['silhouette']:.4f}")
        print(f" - Davies-Bouldin Index: {results['clustering']['davies_bouldin']:.4f}")
    else:
        print(" - Clustering metrics not available (insufficient data)")
    print("\n5. WEIGHT DIVERSITY:")
    print(f" - Mean Normalized Entropy (First Set): {results['diversity']['mean_normalized_entropy_first']:.4f}")
    print(
        f" - Mean Normalized Entropy (Second Set): {results['diversity']['mean_normalized_entropy_second']:.4f}"
    )
    print("\n6. CONVEX HULL METRICS:")
    hull_metrics = results["convex_hull"]
    # The hull metrics are one set of values, so each is printed once rather
    # than being repeated under "First Set" and "Second Set" labels.
    print(f" - Volume/Area: {hull_metrics['volume']:.6f}")
    if hull_metrics.get("volume_ratio") is not None:
        print(f" - Volume Ratio (vs Data): {hull_metrics['volume_ratio']:.4f}")
    print(f" - Dimensionality: {hull_metrics['dimensionality']}")
    print(f" - Is Degenerate: {hull_metrics['is_degenerate']}")
    if "error" in hull_metrics:
        print(f" - Error: {hull_metrics['error']}")
    print("\n" + "=" * 50)
[docs]
def print_summary(self, results: dict):
    """Print a summary of the evaluation results.

    Prints section headers followed by the mean separation distances,
    overall purities and mean normalized entropies for both archetype sets,
    each formatted to four decimal places.

    Args:
        results: Dictionary of evaluation results. It must contain the
            ``separation``, ``purity`` and ``diversity`` sections with the
            keys read below; a missing key raises ``KeyError``.
    """
    print("\n=== Biarchetypal Analysis Evaluation Summary ===")
    print("\n--- Separation Metrics ---")
    print(f"Mean Distance (First Set): {results['separation']['mean_distance_first']:.4f}")
    print(f"Mean Distance (Second Set): {results['separation']['mean_distance_second']:.4f}")
    # NOTE(review): cross-set distance printing is disabled; confirm that
    # 'mean_cross_distance' carries a meaningful value before re-enabling it.
    # print(f"Mean Cross-Set Distance: {results['separation']['mean_cross_distance']:.4f}")
    print("\n--- Purity Metrics ---")
    print(f"Overall Purity (First Set): {results['purity']['overall_purity_first']:.4f}")
    print(f"Overall Purity (Second Set): {results['purity']['overall_purity_second']:.4f}")
    print("\n--- Weight Diversity Metrics ---")
    print(f"Mean Normalized Entropy (First Set): {results['diversity']['mean_normalized_entropy_first']:.4f}")
    print(f"Mean Normalized Entropy (Second Set): {results['diversity']['mean_normalized_entropy_second']:.4f}")