"""Advanced tools for extracting meaningful insights from archetypal representations.
This module provides specialized metrics and techniques for translating mathematical
archetypal models into domain-relevant interpretations. These tools address the
critical challenge of making abstract archetypal patterns understandable by:
1. Quantifying interpretability characteristics of discovered archetypes
2. Revealing feature-level insights about what each archetype represents
3. Determining optimal archetype configurations for maximum meaningfulness
4. Assessing stability and reliability of derived interpretations
These capabilities are essential for bridging the gap between algorithmic discovery
and practical application, enabling stakeholders to leverage archetypal analysis
for meaningful decision-making, pattern discovery, and knowledge extraction.
"""
from typing import Any
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from ..models.archetypes import ImprovedArchetypalAnalysis
from ..models.base import ArchetypalAnalysis
from ..models.biarchetypes import BiarchetypalAnalysis
from ..tools.evaluation import ArchetypalAnalysisEvaluator
[docs]
class ArchetypalAnalysisInterpreter:
"""Advanced interpreter for extracting meaningful insights from archetypal models.
This class provides specialized metrics and visualization tools for translating
abstract mathematical archetypes into understandable, domain-relevant patterns.
Beyond basic evaluation, this interpreter focuses on:
- Quantifying interpretability characteristics of archetypes
- Determining optimal archetype configurations for maximum meaningfulness
- Assessing feature importance and pattern distinctiveness
- Measuring stability and consistency of discovered archetypes
These capabilities address the critical challenge of making archetypal analysis
results accessible and actionable for domain experts and decision-makers,
particularly in exploratory analysis and knowledge discovery applications.
"""
[docs]
def __init__(self, models_dict: dict[int, ArchetypalAnalysis] | None = None) -> None:
"""Initialize the interpreter with optional model collection.
This constructor can either create an empty interpreter for later model
addition or initialize with a pre-fitted collection of models with
different archetype counts. The latter enables comparative analysis
across model complexities for optimal configuration selection.
Args:
models_dict: Optional dictionary mapping archetype counts to fitted
models for comparative interpretation
"""
self.models_dict: dict[int, ArchetypalAnalysis] | None = models_dict or None
self.results: dict[int, dict[str, Any]] | None = None
[docs]
def add_model(self, n_archetypes: int, model: ArchetypalAnalysis) -> "ArchetypalAnalysisInterpreter":
"""Register a fitted archetypal model for interpretation.
This method builds the interpreter's model collection incrementally,
allowing comparative analysis across different archetype configurations.
Each model is validated to ensure it has been properly fitted before
inclusion in the comparative framework.
Args:
n_archetypes: Number of archetypes in the model (key for retrieval)
model: Fitted archetypal model to include in the analysis
Returns:
Self - for method chaining
"""
if model.archetypes is None or model.weights is None:
raise ValueError("Model must be fitted before adding to interpreter")
if self.models_dict is None:
self.models_dict = {}
self.models_dict[n_archetypes] = model
return self
[docs]
def feature_distinctiveness(self, archetypes: np.ndarray) -> np.ndarray:
"""Quantify how uniquely each archetype represents specific feature patterns.
This interpretability metric measures how well each archetype captures unique
feature patterns not represented by other archetypes. High distinctiveness
indicates:
- The archetype represents a truly unique data pattern
- Features have meaningful peak values in this archetype
- The archetype makes a non-redundant contribution to the model
- Interpretation can focus on specific distinguishing characteristics
Low distinctiveness suggests potential redundancy or that the archetype
represents a subtle pattern variation rather than a fundamentally distinct type.
Args:
archetypes: Archetype matrix (n_archetypes, n_features)
Returns:
Array of distinctiveness scores for each archetype, where higher
values indicate more distinctive feature utilization
"""
n_archetypes, n_features = archetypes.shape
distinctiveness_scores = np.zeros(n_archetypes)
for i in range(n_archetypes):
# Calculate the difference between this archetype's values and the maximum values of other archetypes
other_archetypes = np.delete(archetypes, i, axis=0)
max_others = np.max(other_archetypes, axis=0) if len(other_archetypes) > 0 else np.zeros(n_features)
distinctiveness = archetypes[i] - max_others
# Sum the positive differences (features that are particularly prominent in this archetype)
distinctiveness_scores[i] = np.sum(np.maximum(0, distinctiveness))
return distinctiveness_scores
[docs]
def sparsity_coefficient(self, archetypes: np.ndarray, percentile: float = 80) -> np.ndarray:
"""Measure interpretability through feature utilization concentration.
This metric quantifies how selectively each archetype utilizes features,
based on the cognitive science principle that humans can most effectively
interpret patterns defined by a small number of prominent characteristics.
High sparsity indicates:
- The archetype focuses on a specific subset of features
- Interpretation can highlight a manageable number of key attributes
- The pattern has clear defining characteristics
- Domain experts can more easily understand and label the archetype
Low sparsity suggests a complex pattern utilizing many features, which
may be harder to interpret but potentially more faithful to complex phenomena.
Args:
archetypes: Archetype matrix (n_archetypes, n_features)
percentile: Threshold for considering features as prominent (higher
values produce more selective feature identification)
Returns:
Array of sparsity scores for each archetype (higher values indicate
more focused feature utilization and better interpretability)
"""
n_archetypes, n_features = archetypes.shape
sparsity_scores = np.zeros(n_archetypes)
for i in range(n_archetypes):
# Calculate feature importance (e.g., Z-scores)
importance = np.abs(archetypes[i])
if np.std(importance) > 1e-10: # Standardize only if variance is non-zero
importance = (importance - np.mean(importance)) / np.std(importance)
# Calculate the proportion of features above a certain percentile
threshold = np.percentile(importance, percentile)
prominent_features = np.sum(importance >= threshold)
sparsity_scores[i] = prominent_features / n_features
# Lower scores indicate higher sparsity (better interpretability)
return 1 - sparsity_scores
[docs]
def cluster_purity(self, weights: np.ndarray, threshold: float = 0.6) -> tuple[np.ndarray, float]:
"""Assess archetype interpretability through assignment clarity.
This metric evaluates how cleanly each archetype captures a distinct
subset of data points, based on the principle that interpretable archetypes
should represent clear, distinguishable patterns. High purity indicates:
- The archetype represents a coherent, well-defined pattern
- Data points can be meaningfully assigned to specific archetypes
- The model has discovered genuine structure rather than arbitrary positions
- Users can confidently interpret new samples through dominant archetypes
Low purity suggests archetypes may be capturing overlapping patterns or
that the underlying data lacks clear archetypal structure.
Args:
weights: Weight matrix (n_samples, n_archetypes)
threshold: Minimum weight to consider an archetype dominant
Returns:
Tuple containing: (1) purity scores for each archetype and
(2) average purity across all archetypes
"""
n_samples, n_archetypes = weights.shape
purity_scores = np.zeros(n_archetypes)
for i in range(n_archetypes):
# Count samples where this archetype is dominant
dominant_samples = np.sum(weights[:, i] >= threshold)
purity_scores[i] = dominant_samples / n_samples
return np.asarray(purity_scores), float(np.mean(purity_scores))
[docs]
def feature_consistency(
self,
X: np.ndarray,
n_archetypes: int,
n_trials: int = 5,
top_k: int = 5,
random_seed: int = 42,
) -> np.ndarray:
"""Evaluate the stability of feature importance across multiple initializations.
This reliability assessment measures whether the same features consistently
define each archetype across different optimization runs. Consistency is
critical for interpretation because:
- Unstable feature importance undermines confidence in interpretations
- Reliable patterns indicate genuine data structure rather than optimization artifacts
- Consistent archetypes enable more dependable knowledge extraction
- Higher consistency justifies stronger claims about discovered patterns
This analysis is particularly important in exploratory contexts where
results guide hypothesis generation or decision-making.
Args:
X: Data matrix for fitting trial models
n_archetypes: Archetype count to evaluate
n_trials: Number of different initializations for consistency testing
top_k: Number of top features to consider in consistency calculation
random_seed: Base random seed (incremented for each trial)
Returns:
Array of consistency scores for each archetype (higher values
indicate more stable feature importance across initializations)
"""
importance_matrices = []
for i in range(n_trials):
model = ImprovedArchetypalAnalysis(n_archetypes=n_archetypes, random_seed=random_seed + i)
model.fit(X)
# Calculate feature importance matrix
evaluator = ArchetypalAnalysisEvaluator(model)
importance = evaluator.archetype_feature_importance().values
# Store feature importance rankings for each archetype
rankings = np.argsort(-importance, axis=1)
importance_matrices.append(rankings)
# Calculate consistency of top-K features
consistency_scores = np.zeros(n_archetypes)
top_k = min(top_k, X.shape[1])
for i in range(n_archetypes):
overlap_count = 0.0
for j in range(n_trials):
for k in range(j + 1, n_trials):
set1 = set(importance_matrices[j][i, :top_k])
set2 = set(importance_matrices[k][i, :top_k])
overlap = len(set1.intersection(set2))
overlap_count += overlap / top_k
total_comparisons = (n_trials * (n_trials - 1)) / 2
consistency_scores[i] = overlap_count / total_comparisons if total_comparisons > 0 else 0
return consistency_scores
[docs]
def evaluate_all_models(self, X: np.ndarray) -> dict[int, dict[str, Any]]:
"""
Evaluate interpretability metrics for all models.
Args:
X: Original data matrix
Returns:
Dictionary of results per number of archetypes
"""
if not self.models_dict:
raise ValueError("No models available for evaluation")
if self.results is None:
self.results = {}
for k, model in self.models_dict.items():
if model.archetypes is None:
raise ValueError(f"Model with {k} archetypes must be fitted before evaluation")
if model.weights is None:
raise ValueError(f"Model with {k} archetypes must be fitted before evaluation")
if k not in self.results:
self.results[k] = {}
# Calculate various interpretability metrics
distinctiveness = self.feature_distinctiveness(model.archetypes)
sparsity = self.sparsity_coefficient(model.archetypes)
purity, avg_purity = self.cluster_purity(model.weights)
# Calculate average metrics
avg_distinctiveness = np.mean(distinctiveness)
avg_sparsity = np.mean(sparsity)
# Calculate overall interpretability score (higher is better)
interpretability_score = (avg_distinctiveness + avg_sparsity + avg_purity) / 3
self.results[k] = {
"distinctiveness": distinctiveness,
"sparsity": sparsity,
"purity": purity,
"avg_distinctiveness": avg_distinctiveness,
"avg_sparsity": avg_sparsity,
"avg_purity": avg_purity,
"interpretability_score": interpretability_score,
}
# Calculate information gain
try:
gains = self.information_gain(X)
for k, gain in gains:
if k in self.results:
self.results[k]["information_gain"] = gain
# Evaluate balance between interpretability and information gain
for k in list(self.models_dict.keys())[1:]: # Skip the first model
if "information_gain" in self.results[k]:
gain = self.results[k]["information_gain"]
interp = self.results[k]["interpretability_score"]
# Calculate balance score using harmonic mean
if gain + interp > 0:
self.results[k]["balance_score"] = 2 * (gain * interp) / (gain + interp)
else:
self.results[k]["balance_score"] = 0
except Exception as e:
print(f"Warning: Could not compute information gain: {e}")
return self.results
[docs]
def suggest_optimal_archetypes(self, method: str = "balance") -> int:
"""
Suggest optimal number of archetypes based on interpretability metrics.
Args:
method: Method to use for selection ('balance', 'interpretability', or 'information_gain')
Returns:
Optimal number of archetypes
"""
if not self.results:
raise ValueError("Must run evaluate_all_models() first")
if self.models_dict is None:
raise ValueError("No models available for optimal archetype selection")
if method == "balance" and all("balance_score" in self.results[k] for k in list(self.models_dict.keys())[1:]):
scores = {k: self.results[k]["balance_score"] for k in list(self.models_dict.keys())[1:]}
best_k = max(scores, key=lambda k: scores[k])
elif method == "interpretability":
scores = {k: self.results[k]["interpretability_score"] for k in list(self.models_dict.keys())}
best_k = max(scores, key=lambda k: scores[k])
elif method == "information_gain" and all(
"information_gain" in self.results[k] for k in list(self.models_dict.keys())[1:]
):
# Detect decay in information gain (elbow method)
ks = sorted(k for k in list(self.models_dict.keys()) if k > min(list(self.models_dict.keys())))
gains = [self.results[k]["information_gain"] for k in ks]
# Calculate differences in information gain
gain_diffs = np.diff(gains)
if len(gain_diffs) > 0:
# Detect the largest decrease
elbow_idx = np.argmin(gain_diffs)
best_k = ks[elbow_idx + 1] # +1 because diff reduces array size by 1
else:
best_k = min(self.models_dict.keys())
else:
raise ValueError(f"Method '{method}' not applicable with current results")
return int(best_k)
[docs]
def plot_interpretability_metrics(self):
"""Plot interpretability metrics for different numbers of archetypes."""
if not self.results:
raise ValueError("Must run evaluate_all_models() first")
ks = sorted(self.results.keys())
# Prepare metrics for plotting
avg_distinctiveness = [self.results[k]["avg_distinctiveness"] for k in ks]
avg_sparsity = [self.results[k]["avg_sparsity"] for k in ks]
avg_purity = [self.results[k]["avg_purity"] for k in ks]
interpretability = [self.results[k]["interpretability_score"] for k in ks]
information_gain = []
for k in ks[1:]: # Skip first k as it has no information gain
information_gain.append(self.results[k].get("information_gain", np.nan))
balance_scores = []
for k in ks[1:]: # Skip first k
balance_scores.append(self.results[k].get("balance_score", np.nan))
# Create plots
_, axes = plt.subplots(3, 1, figsize=(12, 15))
# Plot interpretability metrics
axes[0].plot(ks, avg_distinctiveness, "o-", label="Distinctiveness")
axes[0].plot(ks, avg_sparsity, "s-", label="Sparsity")
axes[0].plot(ks, avg_purity, "^-", label="Purity")
axes[0].plot(ks, interpretability, "D-", label="Overall Interpretability")
axes[0].set_xlabel("Number of Archetypes")
axes[0].set_ylabel("Score")
axes[0].set_title("Interpretability Metrics vs Number of Archetypes")
axes[0].legend()
axes[0].grid(True)
# Plot information gain
if len(information_gain) > 0 and not all(np.isnan(information_gain)):
axes[1].plot(ks[1:], information_gain, "o-")
axes[1].set_xlabel("Number of Archetypes")
axes[1].set_ylabel("Information Gain")
axes[1].set_title("Information Gain from Adding Archetypes")
axes[1].grid(True)
else:
axes[1].text(
0.5,
0.5,
"No information gain data available",
horizontalalignment="center",
verticalalignment="center",
)
# Plot balance score
if len(balance_scores) > 0 and not all(np.isnan(balance_scores)):
axes[2].plot(ks[1:], balance_scores, "o-")
axes[2].set_xlabel("Number of Archetypes")
axes[2].set_ylabel("Balance Score")
axes[2].set_title("Interpretability-Information Gain Balance")
axes[2].grid(True)
# Highlight best k according to balance score
if not all(np.isnan(balance_scores)):
best_idx = np.nanargmax(balance_scores)
best_k = ks[1:][best_idx]
axes[2].axvline(best_k, color="r", linestyle="--", label=f"Optimal k={best_k}")
axes[2].legend()
else:
axes[2].text(
0.5,
0.5,
"No balance score data available",
horizontalalignment="center",
verticalalignment="center",
)
plt.tight_layout()
plt.show()
[docs]
class BiarchetypalAnalysisInterpreter:
"""
Interpreter for Biarchetypal Analysis results, focusing on interpretability metrics.
Provides quantitative measures for biarchetype interpretability and optimal number selection.
"""
[docs]
def __init__(self, models_dict: dict[tuple[int, int], BiarchetypalAnalysis] | None = None) -> None:
"""
Initialize the interpreter.
Args:
models_dict: Optional dictionary of {n_archetypes_first, n_archetypes_second: model} pairs
"""
self.models_dict: dict[tuple[int, int], BiarchetypalAnalysis] = models_dict or {}
self.results: dict[tuple[int, int], dict[str, Any]] = {}
[docs]
def add_model(
self, n_archetypes_first: int, n_archetypes_second: int, model: BiarchetypalAnalysis
) -> "BiarchetypalAnalysisInterpreter":
"""Add a fitted model to the interpreter."""
# Verify that the model is fitted by using the get_all_archetypes method
try:
model.get_all_archetypes()
except ValueError as e:
raise ValueError(f"Model must be fitted before adding to interpreter: {e}") from e
self.models_dict[n_archetypes_first, n_archetypes_second] = model
return self
[docs]
def feature_distinctiveness(self, archetypes: np.ndarray) -> np.ndarray:
"""
Calculate how distinctive each archetype is in terms of feature values.
Args:
archetypes: Archetype matrix (n_archetypes, n_features)
Returns:
Array of distinctiveness scores for each archetype
"""
n_archetypes, n_features = archetypes.shape
distinctiveness_scores = np.zeros(n_archetypes)
for i in range(n_archetypes):
# Calculate the difference between this archetype's values and the maximum values of other archetypes
other_archetypes = np.delete(archetypes, i, axis=0)
max_others = np.max(other_archetypes, axis=0) if len(other_archetypes) > 0 else np.zeros(n_features)
distinctiveness = archetypes[i] - max_others
# Sum the positive differences (features that are particularly prominent in this archetype)
distinctiveness_scores[i] = np.sum(np.maximum(0, distinctiveness))
return distinctiveness_scores
[docs]
def sparsity_coefficient(self, archetypes: np.ndarray, percentile: float = 80) -> np.ndarray:
"""
Calculate sparsity of each archetype's feature representation.
Args:
archetypes: Archetype matrix (n_archetypes, n_features)
percentile: Percentile threshold for considering features as prominent
Returns:
Array of sparsity scores for each archetype (higher is more interpretable)
"""
n_archetypes, n_features = archetypes.shape
sparsity_scores = np.zeros(n_archetypes)
for i in range(n_archetypes):
# Calculate feature importance (e.g., Z-scores)
importance = np.abs(archetypes[i])
if np.std(importance) > 1e-10: # Standardize only if variance is non-zero
importance = (importance - np.mean(importance)) / np.std(importance)
# Calculate proportion of features above the specified percentile
threshold = np.percentile(importance, percentile)
prominent_features = np.sum(importance >= threshold)
sparsity_scores[i] = prominent_features / n_features
# Lower scores indicate higher sparsity (better interpretability)
return 1 - sparsity_scores
[docs]
def cluster_purity(self, weights: np.ndarray, threshold: float = 0.6) -> tuple[np.ndarray, float]:
"""
Calculate purity of each archetype's associated data points.
Args:
weights: Weight matrix (n_samples, n_archetypes)
threshold: Threshold for considering an archetype as dominant
Returns:
Tuple of purity scores per archetype, average purity
"""
n_samples, n_archetypes = weights.shape
purity_scores = np.zeros(n_archetypes)
for i in range(n_archetypes):
# Count samples where this archetype is dominant
dominant_samples = np.sum(weights[:, i] >= threshold)
purity_scores[i] = dominant_samples / n_samples
return purity_scores, float(np.mean(purity_scores))
[docs]
def evaluate_all_models(self, X: np.ndarray) -> dict[tuple[int, int], dict[str, Any]]:
"""
Evaluate interpretability metrics for all models.
Args:
X: Original data matrix
Returns:
Dictionary of results per combination of archetypes
"""
if not self.models_dict:
raise ValueError("No models available for evaluation")
self.results = {}
for (k1, k2), model in self.models_dict.items():
try:
# Retrieve row and column archetypes using the get_all_archetypes method
archetypes_first, archetypes_second = model.get_all_archetypes()
# Retrieve row and column weights using the get_all_weights method
weights_first, weights_second = model.get_all_weights()
except ValueError as e:
raise ValueError(f"Model with archetypes ({k1}, {k2}) must be fitted before evaluation: {e}") from e
# First archetype set interpretability metrics
distinctiveness_first = self.feature_distinctiveness(np.array(archetypes_first))
sparsity_first = self.sparsity_coefficient(np.array(archetypes_first))
purity_first, avg_purity_first = self.cluster_purity(np.array(weights_first))
# Second archetype set interpretability metrics
distinctiveness_second = self.feature_distinctiveness(np.array(archetypes_second))
sparsity_second = self.sparsity_coefficient(np.array(archetypes_second))
purity_second, avg_purity_second = self.cluster_purity(np.array(weights_second))
# Calculate averages
avg_distinctiveness_first = np.mean(distinctiveness_first)
avg_sparsity_first = np.mean(sparsity_first)
avg_distinctiveness_second = np.mean(distinctiveness_second)
avg_sparsity_second = np.mean(sparsity_second)
# Interpretability scores (higher is better)
interpretability_first = (avg_distinctiveness_first + avg_sparsity_first + avg_purity_first) / 3
interpretability_second = (avg_distinctiveness_second + avg_sparsity_second + avg_purity_second) / 3
# Combined score for both sets
combined_interpretability = (interpretability_first + interpretability_second) / 2
# Calculate reconstruction error
X_recon = model.reconstruct(X)
recon_error = np.mean(np.sum((X - X_recon) ** 2, axis=1))
self.results[k1, k2] = {
# First archetype set
"distinctiveness_first": distinctiveness_first,
"sparsity_first": sparsity_first,
"purity_first": purity_first,
"avg_distinctiveness_first": avg_distinctiveness_first,
"avg_sparsity_first": avg_sparsity_first,
"avg_purity_first": avg_purity_first,
"interpretability_first": interpretability_first,
# Second archetype set
"distinctiveness_second": distinctiveness_second,
"sparsity_second": sparsity_second,
"purity_second": purity_second,
"avg_distinctiveness_second": avg_distinctiveness_second,
"avg_sparsity_second": avg_sparsity_second,
"avg_purity_second": avg_purity_second,
"interpretability_second": interpretability_second,
# Combined scores
"combined_interpretability": combined_interpretability,
"reconstruction_error": recon_error,
}
# Calculate information gain
self.compute_information_gain(X)
return self.results
[docs]
def suggest_optimal_biarchetypes(self, method: str = "balance") -> tuple[int, int]:
"""
Suggest optimal archetype number combination based on interpretability metrics.
Args:
method: Method to use for selection ('balance', 'interpretability', or 'information_gain')
Returns:
Optimal combination of n_archetypes_first, n_archetypes_second
"""
if not self.results:
raise ValueError("Must run evaluate_all_models() first")
if method == "balance":
# Only use models that have a balance_score
scores: dict[tuple[int, int], float] = {}
for k in self.models_dict:
if "balance_score" in self.results[k]:
scores[k] = self.results[k]["balance_score"]
if scores: # Ensure scores is not empty
best_k = max(scores.items(), key=lambda x: x[1])[0]
else:
# Fall back to interpretability if balance scores aren't available
return self.suggest_optimal_biarchetypes(method="interpretability")
elif method == "interpretability":
scores = {k: self.results[k]["combined_interpretability"] for k in self.models_dict}
best_k = max(scores.items(), key=lambda x: x[1])[0]
elif method == "information_gain":
# Only use models that have information_gain
scores = {}
min_k = min(self.models_dict.keys(), key=lambda x: x[0] + x[1])
for k in self.models_dict:
if k != min_k and "information_gain" in self.results[k]:
scores[k] = self.results[k]["information_gain"]
if scores: # Ensure scores is not empty
best_k = max(scores.items(), key=lambda x: x[1])[0]
else:
# Fall back to interpretability if information gain isn't available
return self.suggest_optimal_biarchetypes(method="interpretability")
else:
raise ValueError(f"Method '{method}' not applicable with current results")
return best_k
[docs]
def plot_interpretability_heatmap(self) -> plt.Figure:
"""
Plot heatmaps of interpretability metrics for different archetype number combinations.
Returns:
The matplotlib figure object
"""
if not self.results:
raise ValueError("Must run evaluate_all_models() first")
# Get available archetype number combinations
k1_values = sorted({k1 for k1, _ in self.models_dict})
k2_values = sorted({k2 for _, k2 in self.models_dict})
# Store interpretability scores in matrix form
interpretability_matrix = np.zeros((len(k1_values), len(k2_values)))
balance_matrix = np.zeros((len(k1_values), len(k2_values)))
error_matrix = np.zeros((len(k1_values), len(k2_values)))
# Prepare data
for i, k1 in enumerate(k1_values):
for j, k2 in enumerate(k2_values):
if (k1, k2) in self.results:
interpretability_matrix[i, j] = self.results[k1, k2]["combined_interpretability"]
if "balance_score" in self.results[k1, k2]:
balance_matrix[i, j] = self.results[k1, k2]["balance_score"]
error_matrix[i, j] = self.results[k1, k2]["reconstruction_error"]
# Create plots
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
# Interpretability score heatmap
sns.heatmap(
interpretability_matrix,
annot=True,
fmt=".3f",
cmap="viridis",
xticklabels=k2_values,
yticklabels=k1_values,
ax=axes[0],
)
axes[0].set_xlabel("Number of Second Archetypes")
axes[0].set_ylabel("Number of First Archetypes")
axes[0].set_title("Combined Interpretability Score")
# Balance score heatmap
sns.heatmap(
balance_matrix,
annot=True,
fmt=".3f",
cmap="coolwarm",
xticklabels=k2_values,
yticklabels=k1_values,
ax=axes[1],
)
axes[1].set_xlabel("Number of Second Archetypes")
axes[1].set_ylabel("Number of First Archetypes")
axes[1].set_title("Interpretability-Information Gain Balance")
# Reconstruction error heatmap
sns.heatmap(
error_matrix,
annot=True,
fmt=".3f",
cmap="rocket_r",
xticklabels=k2_values,
yticklabels=k1_values,
ax=axes[2],
)
axes[2].set_xlabel("Number of Second Archetypes")
axes[2].set_ylabel("Number of First Archetypes")
axes[2].set_title("Reconstruction Error")
plt.tight_layout()
return fig