Source code for archetypax.tools.evaluation

"""Evaluation metrics for Archetypal Analysis."""

import math  # Import math module for factorial function
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.spatial import ConvexHull, QhullError
from scipy.spatial.distance import cdist
from scipy.stats import entropy
from sklearn.metrics import davies_bouldin_score, silhouette_score

from ..models.base import ArchetypalAnalysis


[docs] class ArchetypalAnalysisEvaluator: """ Evaluator for Archetypal Analysis results, especially for high-dimensional data. Provides metrics and visualizations to assess model quality. """
[docs] def __init__(self, model: ArchetypalAnalysis): """ Initialize the evaluator. Args: model: Fitted ArchetypalAnalysis model """ self.model = model if model.archetypes is None or model.weights is None: raise ValueError("Model must be fitted before evaluation") # Cache some frequently used values self.n_archetypes = model.archetypes.shape[0] self.n_features = model.archetypes.shape[1] self.dominant_archetypes = np.argmax(model.weights, axis=1)
[docs] def reconstruction_error(self, X: np.ndarray, metric: str = "frobenius") -> float: """ Calculate the reconstruction error of the model. Args: X: Data matrix metric: Error metric to use ('frobenius', 'mae', 'mse', or 'relative') Returns: Reconstruction error value """ X_reconstructed = self.model.reconstruct() if metric == "frobenius": # Frobenius norm (default) return float(np.linalg.norm(X - X_reconstructed, ord="fro")) elif metric == "mae": # Mean Absolute Error return float(np.mean(np.abs(X - X_reconstructed))) elif metric == "mse": # Mean Squared Error return float(np.mean((X - X_reconstructed) ** 2)) elif metric == "relative": # Relative error return float(np.linalg.norm(X - X_reconstructed, ord="fro") / np.linalg.norm(X, ord="fro")) else: raise ValueError(f"Unknown metric: {metric}. Use 'frobenius', 'mae', 'mse', or 'relative'.")
[docs] def explained_variance(self, X: np.ndarray) -> float: """ Calculate the explained variance of the model. Args: X: Data matrix Returns: Explained variance (0-1) """ X_reconstructed = self.model.reconstruct(X) # Calculate total variance total_variance = np.var(X, axis=0).sum() # Calculate residual variance residual_variance = np.var(X - X_reconstructed, axis=0).sum() # Calculate explained variance explained_var = 1.0 - (residual_variance / total_variance) return float(explained_var)
[docs] def dominant_archetype_purity(self) -> dict[str, Any]: """ Analyze how dominant each archetype is for its assigned samples. Returns: Dictionary with purity metrics """ if self.model.weights is None: raise ValueError("Model must be fitted before evaluating purity") # Get weights for each sample weights: np.ndarray = self.model.weights # Get maximum weight for each sample max_weights = np.max(weights, axis=1) # Calculate average purity for each archetype archetype_purity = {} for i in range(self.n_archetypes): archetype_mask = self.dominant_archetypes == i if np.sum(archetype_mask) > 0: # Check if archetype has any assigned samples avg_purity = np.mean(max_weights[archetype_mask]) archetype_purity[f"Archetype_{i}"] = avg_purity # Calculate overall purity metrics overall_purity = np.mean(max_weights) purity_std = np.std(max_weights) return { "archetype_purity": archetype_purity, "overall_purity": overall_purity, "purity_std": purity_std, "max_weights": max_weights, }
[docs] def archetype_separation(self) -> dict[str, float]: """ Measure how well-separated the archetypes are. Returns: Dictionary with separation metrics """ # Calculate all pairwise distances between archetypes archetype_distances = cdist(self.model.archetypes, self.model.archetypes) # Fill diagonal with NaN to ignore self-distances np.fill_diagonal(archetype_distances, np.nan) # Calculate metrics min_distance = np.nanmin(archetype_distances) max_distance = np.nanmax(archetype_distances) mean_distance = np.nanmean(archetype_distances) return { "min_distance": min_distance, "max_distance": max_distance, "mean_distance": mean_distance, "distance_ratio": min_distance / max_distance if max_distance > 0 else 0, }
[docs] def clustering_metrics(self, X: np.ndarray) -> dict[str, float]: """ Calculate clustering quality metrics by using dominant archetypes as cluster assignments. Args: X: Original data matrix Returns: Dictionary with clustering metrics """ # Need at least 2 archetypes and more samples than archetypes if self.n_archetypes < 2 or X.shape[0] <= self.n_archetypes: return {"silhouette": np.nan, "davies_bouldin": np.nan} try: # Silhouette score (higher is better) silhouette = silhouette_score(X, self.dominant_archetypes) # Davies-Bouldin index (lower is better) davies_bouldin = davies_bouldin_score(X, self.dominant_archetypes) return {"silhouette": silhouette, "davies_bouldin": davies_bouldin} except Exception as e: print(f"Could not compute clustering metrics: {e!s}") return {"silhouette": np.nan, "davies_bouldin": np.nan}
[docs] def archetype_feature_importance(self) -> pd.DataFrame: """ Analyze which features are most important for each archetype. Returns: DataFrame with feature importance for each archetype """ # Get archetypes archetypes = self.model.archetypes if archetypes is None: raise ValueError("Model archetypes must not be None") # Calculate feature-wise z-scores for each archetype feature_means = np.mean(archetypes, axis=0) feature_stds = np.std(archetypes, axis=0) # Avoid division by zero feature_stds = np.where(feature_stds < 1e-10, 1.0, feature_stds) # Calculate z-scores feature_importance = np.abs((archetypes - feature_means) / feature_stds) # Create DataFrame archetype_names = [f"Archetype_{i}" for i in range(self.n_archetypes)] feature_names = [f"Feature_{i}" for i in range(self.n_features)] return pd.DataFrame(feature_importance, index=archetype_names, columns=feature_names)
[docs] def weight_diversity(self) -> dict[str, float]: """ Measure how diverse the weight distributions are across samples. Returns: Dictionary with diversity metrics """ weights = self.model.weights if weights is None: raise ValueError("Model weights must not be None") # Calculate entropy for each sample's weight distribution sample_entropy = np.array([entropy(w) for w in weights]) # Theoretical maximum entropy for uniform distribution max_entropy = np.log(self.n_archetypes) # Normalize entropy (0-1 scale) normalized_entropy = sample_entropy / max_entropy return { "mean_entropy": np.mean(sample_entropy), "mean_normalized_entropy": np.mean(normalized_entropy), "entropy_std": np.std(sample_entropy), "min_entropy": np.min(sample_entropy), "max_entropy": np.max(sample_entropy), }
[docs] def convex_hull_metrics(self) -> dict[str, Any]: """ Calculate metrics related to the convex hull formed by the archetypes. This method evaluates whether the archetypes form a non-degenerate convex hull by calculating its volume/area and comparing it to the data's convex hull. Returns: Dictionary with convex hull metrics including: - volume/area of the convex hull - ratio compared to data hull volume/area - dimensionality of the hull """ archetypes = self.model.archetypes if archetypes is None: raise ValueError("Model must be fitted before evaluating convex hull") # Ensure we have enough archetypes to form a convex hull n_archetypes, n_features = archetypes.shape min_points_needed = min(n_features + 1, n_archetypes) hull_metrics: dict[str, Any] = { "volume": 0.0, "volume_ratio": 0.0, "dimensionality": 0, "is_degenerate": True, } # Check if we have enough points to form a hull if n_archetypes < min_points_needed: hull_metrics["error"] = ( f"Not enough archetypes ({n_archetypes}) to form a convex hull in {n_features}D space" ) return hull_metrics try: # Calculate convex hull of archetypes archetype_hull = ConvexHull(archetypes) hull_metrics["volume"] = archetype_hull.volume hull_metrics["dimensionality"] = archetype_hull.ndim hull_metrics["is_degenerate"] = False # If we have access to the original data, compare to data hull if hasattr(self.model, "X") and self.model.X is not None: try: data_hull = ConvexHull(self.model.X) hull_metrics["data_volume"] = data_hull.volume hull_metrics["volume_ratio"] = archetype_hull.volume / data_hull.volume except QhullError: # Data might not form a valid convex hull hull_metrics["data_volume"] = None hull_metrics["volume_ratio"] = None except QhullError as e: # Handle the case where archetypes form a degenerate convex hull hull_metrics["error"] = f"Degenerate convex hull: {e!s}" hull_metrics["is_degenerate"] = True # If hull calculation failed, calculate the n-dimensional simplex volume using determinant if n_archetypes >= 2: # Need at least 2 points for any meaningful volume try: # Center the archetypes centered = archetypes - np.mean(archetypes, axis=0) # For 2D case (area) if n_features == 2 and n_archetypes >= 3: # Calculate area using Shoelace formula x = archetypes[:, 0] y = archetypes[:, 1] area = 0.5 * np.abs(np.sum(x * np.roll(y, 1) - np.roll(x, 1) * y)) hull_metrics["volume"] = area hull_metrics["dimensionality"] = 2 # For higher dimensions, estimate volume using matrix determinant elif n_archetypes >= n_features + 1: # Select n_features archetypes to form a basis vectors = centered[1 : n_features + 1] - centered[0] # Calculate volume of parallelotope volume = np.abs(np.linalg.det(vectors)) / math.factorial(n_features) hull_metrics["volume"] = volume hull_metrics["dimensionality"] = n_features if hull_metrics["volume"] > 1e-10: hull_metrics["is_degenerate"] = False except Exception as calc_err: # Explicitly convert to string to avoid typing issues error_message = str(calc_err) # Use a placeholder numeric value for error cases hull_metrics["volume"] = 0.0 hull_metrics["calculation_error"] = error_message return hull_metrics
[docs] def plot_convex_hull(self, feature_indices: list[int] | None = None, figsize: tuple[int, int] = (10, 8)) -> None: """ Plot the convex hull formed by archetypes in 2D or 3D. Args: feature_indices: Indices of features to use for visualization (2 or 3 features) figsize: Size of the figure """ archetypes = self.model.archetypes if archetypes is None: raise ValueError("Model must be fitted before plotting convex hull") if feature_indices is None: feature_indices = [0, 1, 2] if archetypes.shape[1] >= 3 else [0, 1] if len(feature_indices) not in [2, 3]: raise ValueError("feature_indices must contain 2 or 3 feature indices for 2D or 3D visualization") selected_archetypes = archetypes[:, feature_indices] plt.figure(figsize=figsize) if len(feature_indices) == 2: plt.scatter( selected_archetypes[:, 0], selected_archetypes[:, 1], s=100, c="r", marker="o", label="Archetypes", ) # Try to plot the convex hull try: hull = ConvexHull(selected_archetypes) for simplex in hull.simplices: plt.plot(selected_archetypes[simplex, 0], selected_archetypes[simplex, 1], "k-") # Add area information area = float(hull.volume) # In 2D, volume is area plt.title(f"Convex Hull of Archetypes (Area: {area:.4f})") except QhullError: plt.title("Archetypes (Degenerate Convex Hull)") # Plot original data if available if hasattr(self.model, "X") and self.model.X is not None: data = self.model.X[:, feature_indices] plt.scatter(data[:, 0], data[:, 1], s=10, alpha=0.5, label="Data") plt.xlabel(f"Feature {feature_indices[0]}") plt.ylabel(f"Feature {feature_indices[1]}") # 3D plot else: ax = plt.figure().add_subplot(111, projection="3d") ax.scatter( selected_archetypes[:, 0], selected_archetypes[:, 1], selected_archetypes[:, 2], # s=100, color="r", marker="o", label="Archetypes", ) # Try to plot the convex hull try: hull = ConvexHull(selected_archetypes) for simplex in hull.simplices: ax.plot( selected_archetypes[simplex, 0], selected_archetypes[simplex, 1], selected_archetypes[simplex, 2], "k-", ) # Add volume information volume = float(hull.volume) ax.set_title(f"Convex Hull of Archetypes (Volume: {volume:.4f})") except QhullError: ax.set_title("Archetypes (Degenerate Convex Hull)") # Plot original data if available if hasattr(self.model, "X") and self.model.X is not None: data = self.model.X[:, feature_indices] ax.scatter( data[:, 0], data[:, 1], data[:, 2], # s=10, alpha=0.3, color="blue", label="Data", ) ax.set_xlabel(f"Feature {feature_indices[0]}") ax.set_ylabel(f"Feature {feature_indices[1]}") if hasattr(ax, "set_zlabel"): ax.set_zlabel(f"Feature {feature_indices[2]}") plt.legend() plt.tight_layout() plt.show()
[docs] def comprehensive_evaluation(self, X: np.ndarray) -> dict[str, Any]: """ Run all evaluation metrics and return comprehensive results. Args: X: Original data matrix Returns: Dictionary with all evaluation metrics """ results = { "reconstruction": { "frobenius": self.reconstruction_error(X, "frobenius"), "mae": self.reconstruction_error(X, "mae"), "mse": self.reconstruction_error(X, "mse"), "relative": self.reconstruction_error(X, "relative"), }, "explained_variance": self.explained_variance(X), "purity": self.dominant_archetype_purity(), "separation": self.archetype_separation(), "clustering": self.clustering_metrics(X), "diversity": self.weight_diversity(), "convex_hull": self.convex_hull_metrics(), } return results
[docs] def print_evaluation_report(self, X: np.ndarray) -> None: """ Print a comprehensive evaluation report. Args: X: Original data matrix """ results = self.comprehensive_evaluation(X) print("\n" + "=" * 50) print(f"ARCHETYPAL ANALYSIS EVALUATION ({self.n_archetypes} archetypes)") print("=" * 50) print("\n1. RECONSTRUCTION METRICS:") print(f" - Reconstruction Error: {results['reconstruction']['relative']:.4f}") print(f" - Explained Variance: {results['explained_variance']:.4f}") print("\n2. ARCHETYPE SEPARATION:") print(f" - Minimum Distance: {results['separation']['min_distance']:.4f}") print(f" - Maximum Distance: {results['separation']['max_distance']:.4f}") print(f" - Mean Distance: {results['separation']['mean_distance']:.4f}") print(f" - Distance Ratio (min/max): {results['separation']['distance_ratio']:.4f}") print("\n3. DOMINANT ARCHETYPE PURITY:") print(f" - Overall Purity: {results['purity']['overall_purity']:.4f}") print(f" - Purity Std Dev: {results['purity']['purity_std']:.4f}") print(" - Per-Archetype Purity:") for archetype, purity in results["purity"]["archetype_purity"].items(): print(f" - {archetype}: {purity:.4f}") print("\n4. CLUSTERING METRICS:") if not np.isnan(results["clustering"]["silhouette"]): print(f" - Silhouette Score: {results['clustering']['silhouette']:.4f}") print(f" - Davies-Bouldin Index: {results['clustering']['davies_bouldin']:.4f}") else: print(" - Clustering metrics not available (insufficient data)") print("\n5. WEIGHT DIVERSITY:") print(f" - Mean Entropy: {results['diversity']['mean_entropy']:.4f}") print(f" - Min Entropy: {results['diversity']['min_entropy']:.4f}") print(f" - Max Entropy: {results['diversity']['max_entropy']:.4f}") print("\n6. CONVEX HULL METRICS:") hull_metrics = results["convex_hull"] print(f" - Volume/Area: {hull_metrics['volume']:.6f}") if hull_metrics.get("volume_ratio") is not None: print(f" - Volume Ratio (vs Data): {hull_metrics['volume_ratio']:.4f}") print(f" - Dimensionality: {hull_metrics['dimensionality']}") print(f" - Is Degenerate: {hull_metrics['is_degenerate']}") if "error" in hull_metrics: print(f" - Error: {hull_metrics['error']}") print("\n" + "=" * 50)
# Visualization methods for high-dimensional data
[docs] def plot_feature_importance_heatmap(self, feature_names: list[str] | None = None) -> None: """ Plot heatmap of feature importance across archetypes. Args: feature_names: Optional list of feature names """ importance_df = self.archetype_feature_importance() # Rename columns if feature names provided if feature_names is not None and len(feature_names) == self.n_features: importance_df = pd.DataFrame(importance_df.values, index=importance_df.index, columns=feature_names) plt.figure(figsize=(12, 8)) sns.heatmap(importance_df, cmap="viridis", annot=True) plt.title("Feature Importance Across Archetypes") plt.xlabel("Features") plt.ylabel("Archetypes") plt.tight_layout() plt.show()
[docs] def plot_archetype_feature_comparison(self, top_n: int = 5, feature_names: list[str] | None = None) -> None: """ Plot radar chart or bar chart comparing top N most important features for each archetype. Args: top_n: Number of top features to display feature_names: Optional list of feature names """ importance_df = self.archetype_feature_importance() # Rename columns if feature names provided if feature_names is not None and len(feature_names) == self.n_features: importance_df = pd.DataFrame(importance_df.values, index=importance_df.index, columns=feature_names) # For each archetype, get the top N most important features plt.figure(figsize=(15, 4 * ((self.n_archetypes + 1) // 2))) for i in range(self.n_archetypes): # Sort features by importance for this archetype archetype_importance = importance_df.iloc[i].sort_values(ascending=False) top_features = archetype_importance.head(top_n) plt.subplot(((self.n_archetypes + 1) // 2), 2, i + 1) bars = plt.bar( np.arange(len(top_features)), top_features.values.astype(float), tick_label=top_features.index, color="skyblue", ) # Add values on top of bars for bar in bars: height = bar.get_height() plt.text( bar.get_x() + bar.get_width() / 2.0, height + 0.05, f"{height:.2f}", ha="center", va="bottom", rotation=0, ) plt.title(f"Archetype {i}: Top {top_n} Features") plt.ylim(0, max(top_features.values) * 1.2) # Add headroom for text plt.xticks(rotation=45, ha="right") plt.tight_layout() plt.tight_layout() plt.show()
[docs] def plot_weight_distributions(self, bins: int = 20) -> None: """ Plot histograms of weight distributions for each archetype. Args: bins: Number of histogram bins """ weights = self.model.weights if weights is None: raise ValueError("Model weights must not be None") plt.figure(figsize=(15, 4 * ((self.n_archetypes + 1) // 2))) for i in range(self.n_archetypes): plt.subplot(((self.n_archetypes + 1) // 2), 2, i + 1) # Get weights for this archetype archetype_weights = weights[:, i] # Plot histogram plt.hist(archetype_weights, bins=bins, alpha=0.7, color="skyblue") plt.title(f"Archetype {i} Weight Distribution") plt.xlabel("Weight") plt.ylabel("Number of Samples") # Add statistics plt.axvline( np.mean(archetype_weights), color="r", linestyle="--", label=f"Mean: {np.mean(archetype_weights):.3f}", ) plt.axvline( np.median(archetype_weights), color="g", linestyle="-", label=f"Median: {np.median(archetype_weights):.3f}", ) plt.legend() plt.tight_layout() plt.show()
[docs] def plot_purity_distribution(self) -> None: """Plot the distribution of dominant archetype weights (purity).""" purity_data = self.dominant_archetype_purity() if "max_weights" not in purity_data: raise ValueError("Max weights data is missing") max_weights = purity_data["max_weights"] if max_weights is None: raise ValueError("Max weights is None") plt.figure(figsize=(10, 6)) # Plot histogram plt.hist(max_weights, bins=20, alpha=0.7, color="skyblue") plt.title("Distribution of Dominant Archetype Weights (Purity)") plt.xlabel("Maximum Weight") plt.ylabel("Number of Samples") # Add statistics plt.axvline( np.mean(max_weights), color="r", linestyle="--", label=f"Mean: {np.mean(max_weights):.3f}", ) plt.axvline( np.median(max_weights), color="g", linestyle="-", label=f"Median: {np.median(max_weights):.3f}", ) # Theoretical threshold for uniform weights uniform_weight = 1.0 / self.n_archetypes plt.axvline( uniform_weight, color="k", linestyle=":", label=f"Uniform: {uniform_weight:.3f}", ) plt.legend() plt.tight_layout() plt.show()
[docs] def plot_distance_matrix(self) -> None: """Plot distance matrix between archetypes.""" # Calculate pairwise distances distances = cdist(self.model.archetypes, self.model.archetypes) plt.figure(figsize=(10, 8)) # Create heatmap sns.heatmap( distances, annot=True, cmap="viridis", xticklabels=[f"A{i}" for i in range(self.n_archetypes)], yticklabels=[f"A{i}" for i in range(self.n_archetypes)], ) plt.title("Distance Matrix Between Archetypes") plt.tight_layout() plt.show()
[docs] def plot_entropy_vs_reconstruction(self, X: np.ndarray, n_samples: int = 1000) -> None: """ Plot relationship between sample entropy and reconstruction error. Args: X: Original data matrix n_samples: Number of samples to plot (random subset) """ weights = self.model.weights X_reconstructed = self.model.reconstruct() if weights is None: raise ValueError("Model weights must not be None") # Calculate point-wise reconstruction error point_errors = np.sqrt(np.sum((X - X_reconstructed) ** 2, axis=1)) # Calculate entropy for each point entropies = np.array([entropy(w) for w in weights]) # Normalize to maximum possible entropy max_entropy = np.log(self.n_archetypes) normalized_entropies = entropies / max_entropy # Select subset if needed if n_samples < len(entropies) and n_samples > 0: indices = np.random.choice(len(entropies), size=n_samples, replace=False) point_errors = point_errors[indices] normalized_entropies = normalized_entropies[indices] dominant_archetypes = self.dominant_archetypes[indices] else: dominant_archetypes = self.dominant_archetypes plt.figure(figsize=(10, 8)) # Scatter plot colored by dominant archetype scatter = plt.scatter( normalized_entropies, point_errors, c=dominant_archetypes, cmap="viridis", alpha=0.6, s=30, ) # Add color legend legend = plt.legend(*scatter.legend_elements(), title="Dominant Archetype") plt.gca().add_artist(legend) # Add correlation coefficient corr = np.corrcoef(normalized_entropies, point_errors)[0, 1] plt.text( 0.05, 0.95, f"Correlation: {corr:.3f}", transform=plt.gca().transAxes, bbox={"facecolor": "white", "alpha": 0.8}, ) plt.xlabel("Normalized Entropy (Diversity)") plt.ylabel("Reconstruction Error") plt.title("Relationship Between Weight Diversity and Reconstruction Error") plt.grid(True, alpha=0.3) plt.tight_layout() plt.show()
[docs] class BiarchetypalAnalysisEvaluator: """ Evaluator for Biarchetypal Analysis results. Provides metrics and visualizations to assess model quality for biarchetypal models, which use two sets of archetypes to represent data. """
[docs] def __init__(self, model): """ Initialize the evaluator. Args: model: Fitted BiarchetypalAnalysis model """ from ..models.biarchetypes import BiarchetypalAnalysis if not isinstance(model, BiarchetypalAnalysis): raise TypeError("Model must be a BiarchetypalAnalysis instance") self.model = model # Check if model is fitted if model.alpha is None or model.beta is None or model.theta is None or model.gamma is None: raise ValueError("Model must be fitted before evaluation") # Cache some frequently used values self.n_archetypes_first = model.n_row_archetypes self.n_archetypes_second = model.n_col_archetypes self.n_features = model.theta.shape[0] # n_features # Calculate dominant archetypes for each set self.dominant_archetypes_first = np.argmax(model.alpha, axis=1) self.dominant_archetypes_second = np.argmax(model.gamma, axis=0)
[docs] def reconstruction_error(self, X: np.ndarray, metric: str = "frobenius") -> float: """ Calculate the reconstruction error of the model. Args: X: Data matrix metric: Error metric to use ('frobenius', 'mae', 'mse', or 'relative') Returns: Reconstruction error value """ X_reconstructed = self.model.reconstruct(X) if metric == "frobenius": return float(np.linalg.norm(X - X_reconstructed, ord="fro") / np.sqrt(X.shape[0])) elif metric == "mse": return float(np.mean((X - X_reconstructed) ** 2)) elif metric == "mae": return float(np.mean(np.abs(X - X_reconstructed))) elif metric == "relative": return float(np.linalg.norm(X - X_reconstructed, ord="fro") / np.linalg.norm(X, ord="fro")) else: raise ValueError(f"Unknown metric: {metric}")
[docs] def explained_variance(self, X: np.ndarray) -> float: """ Calculate the explained variance of the model. Args: X: Data matrix Returns: Explained variance (0-1) """ X_reconstructed = self.model.reconstruct(X) # Calculate total variance total_variance = np.var(X, axis=0).sum() # Calculate residual variance residual_variance = np.var(X - X_reconstructed, axis=0).sum() # Calculate explained variance explained_var = 1.0 - (residual_variance / total_variance) return float(explained_var)
[docs] def archetype_separation(self): """Calculate separation metrics between archetypes. Returns: Dictionary of separation metrics """ # Calculate distances between first set archetypes distances_first = cdist(self.model.alpha, self.model.alpha) np.fill_diagonal(distances_first, np.inf) # Ignore self-distances # Calculate distances between second set archetypes distances_second = cdist(self.model.gamma, self.model.gamma) np.fill_diagonal(distances_second, np.inf) # Ignore self-distances # Calculate metrics for first set metrics_first = {} if np.any(distances_first != np.inf): metrics_first = { "mean_distance_first": np.mean(distances_first[distances_first != np.inf]), "min_distance_first": np.min(distances_first[distances_first != np.inf]), "max_distance_first": np.max(distances_first[distances_first != np.inf]), } else: metrics_first = { "mean_distance_first": 0.0, "min_distance_first": 0.0, "max_distance_first": 0.0, } # Calculate metrics for second set metrics_second = {} if np.any(distances_second != np.inf): metrics_second = { "mean_distance_second": np.mean(distances_second[distances_second != np.inf]), "min_distance_second": np.min(distances_second[distances_second != np.inf]), "max_distance_second": np.max(distances_second[distances_second != np.inf]), } else: metrics_second = { "mean_distance_second": 0.0, "min_distance_second": 0.0, "max_distance_second": 0.0, } # Calculate cross metrics metrics_cross = { "mean_cross_distance": 0.0, "min_cross_distance": 0.0, "max_cross_distance": 0.0, } # Combine all metrics metrics = {**metrics_first, **metrics_second, **metrics_cross} return metrics
[docs] def dominant_archetype_purity(self) -> dict: """ Calculate purity metrics for dominant archetypes. Returns: Dictionary of purity metrics """ # Calculate purity for first set archetype_counts_first = np.bincount(self.dominant_archetypes_first, minlength=self.n_archetypes_first) archetype_purity_first = archetype_counts_first / np.sum(archetype_counts_first) # Calculate purity for second set archetype_counts_second = np.bincount(self.dominant_archetypes_second, minlength=self.n_archetypes_second) archetype_purity_second = archetype_counts_second / np.sum(archetype_counts_second) # Calculate overall metrics return { "archetype_purity_first": archetype_purity_first, "archetype_purity_second": archetype_purity_second, "overall_purity_first": np.max(archetype_purity_first) if archetype_purity_first.size > 0 else 0, "overall_purity_second": np.max(archetype_purity_second) if archetype_purity_second.size > 0 else 0, "purity_std_first": np.std(archetype_purity_first), "purity_std_second": np.std(archetype_purity_second), }
[docs] def weight_diversity(self) -> dict: """ Calculate diversity metrics for archetype weights. Returns: Dictionary of diversity metrics """ if self.model.alpha is None or self.model.gamma is None: raise ValueError("Model must be fitted before calculating weight diversity") # Calculate entropy for first set weights entropies_first = -np.sum(self.model.alpha * np.log2(self.model.alpha + 1e-10), axis=1) max_entropy_first = np.log2(self.model.alpha.shape[1]) # Add check to prevent division by zero if max_entropy_first > 0: normalized_entropies_first = entropies_first / max_entropy_first else: normalized_entropies_first = np.zeros_like(entropies_first) # Calculate entropy for second set weights entropies_second = -np.sum(self.model.gamma * np.log2(self.model.gamma + 1e-10), axis=0) max_entropy_second = np.log2(self.model.gamma.shape[0]) # Add check to prevent division by zero if max_entropy_second > 0: normalized_entropies_second = entropies_second / max_entropy_second else: normalized_entropies_second = np.zeros_like(entropies_second) # Calculate metrics metrics = { "mean_entropy_first": np.mean(entropies_first), "entropy_std_first": np.std(entropies_first), "max_entropy_first": np.max(entropies_first), "mean_normalized_entropy_first": np.mean(normalized_entropies_first), "mean_entropy_second": np.mean(entropies_second), "entropy_std_second": np.std(entropies_second), "max_entropy_second": np.max(entropies_second), "mean_normalized_entropy_second": np.mean(normalized_entropies_second), } return metrics
[docs] def comprehensive_evaluation(self, X: np.ndarray) -> dict: """ Perform a comprehensive evaluation of the model. Args: X: Data matrix Returns: Dictionary of evaluation metrics """ # Calculate reconstruction metrics reconstruction_metrics = { "frobenius": self.reconstruction_error(X, metric="frobenius"), "mse": self.reconstruction_error(X, metric="mse"), "mae": self.reconstruction_error(X, metric="mae"), "relative": self.reconstruction_error(X, metric="relative"), "explained_variance": self.explained_variance(X), } # Get other metrics separation_metrics = self.archetype_separation() purity_metrics = self.dominant_archetype_purity() diversity_metrics = self.weight_diversity() # Combine all metrics results = { "reconstruction": reconstruction_metrics, "separation": separation_metrics, "purity": purity_metrics, "diversity": diversity_metrics, } # Calculate other metrics if "clustering" not in results: results["clustering"] = {"silhouette": np.nan, "davies_bouldin": np.nan} if "convex_hull" not in results: results["convex_hull"] = { "volume": 0.0, "dimensionality": 0, "is_degenerate": True, "volume_ratio": 0.0, } return results
[docs] def print_evaluation_report(self, X: np.ndarray) -> None: """ Print a comprehensive evaluation report. Args: X: Original data matrix """ results = self.comprehensive_evaluation(X) print("\n" + "=" * 50) print( f"ARCHETYPAL ANALYSIS EVALUATION ({self.n_archetypes_first} archetypes, {self.n_archetypes_second} archetypes)" ) print("=" * 50) print("\n1. RECONSTRUCTION METRICS:") print(f" - Reconstruction Error: {results['reconstruction']['relative']:.4f}") print(f" - Explained Variance: {results['reconstruction']['explained_variance']:.4f}") print("\n2. ARCHETYPE SEPARATION:") print(f" - Minimum Distance (First Set): {results['separation']['min_distance_first']:.4f}") print(f" - Maximum Distance (First Set): {results['separation']['max_distance_first']:.4f}") print(f" - Mean Distance (First Set): {results['separation']['mean_distance_first']:.4f}") print(f" - Minimum Distance (Second Set): {results['separation']['min_distance_second']:.4f}") print(f" - Maximum Distance (Second Set): {results['separation']['max_distance_second']:.4f}") print(f" - Mean Distance (Second Set): {results['separation']['mean_distance_second']:.4f}") print("\n3. DOMINANT ARCHETYPE PURITY:") print(f" - Overall Purity (First Set): {results['purity']['overall_purity_first']:.4f}") print(f" - Overall Purity (Second Set): {results['purity']['overall_purity_second']:.4f}") print(" - Per-Archetype Purity (First Set):") for archetype, purity in results["purity"]["archetype_purity_first"].items(): print(f" - {archetype}: {purity:.4f}") print(" - Per-Archetype Purity (Second Set):") for archetype, purity in results["purity"]["archetype_purity_second"].items(): print(f" - {archetype}: {purity:.4f}") print("\n4. CLUSTERING METRICS:") if not np.isnan(results["clustering"]["silhouette"]): print(f" - Silhouette Score: {results['clustering']['silhouette']:.4f}") print(f" - Davies-Bouldin Index: {results['clustering']['davies_bouldin']:.4f}") else: print(" - Clustering metrics not available (insufficient data)") print("\n5. WEIGHT DIVERSITY:") print(f" - Mean Normalized Entropy (First Set): {results['diversity']['mean_normalized_entropy_first']:.4f}") print( f" - Mean Normalized Entropy (Second Set): {results['diversity']['mean_normalized_entropy_second']:.4f}" ) print("\n6. CONVEX HULL METRICS:") hull_metrics = results["convex_hull"] print(f" - Volume/Area (First Set): {hull_metrics['volume']:.6f}") print(f" - Volume/Area (Second Set): {hull_metrics['volume']:.6f}") if hull_metrics.get("volume_ratio") is not None: print(f" - Volume Ratio (vs Data) (First Set): {hull_metrics['volume_ratio']:.4f}") print(f" - Volume Ratio (vs Data) (Second Set): {hull_metrics['volume_ratio']:.4f}") print(f" - Dimensionality (First Set): {hull_metrics['dimensionality']}") print(f" - Dimensionality (Second Set): {hull_metrics['dimensionality']}") print(f" - Is Degenerate (First Set): {hull_metrics['is_degenerate']}") print(f" - Is Degenerate (Second Set): {hull_metrics['is_degenerate']}") if "error" in hull_metrics: print(f" - Error: {hull_metrics['error']}") print("\n" + "=" * 50)
[docs] def print_summary(self, results: dict): """Print a summary of the evaluation results. Args: results: Dictionary of evaluation results """ print("\n=== Biarchetypal Analysis Evaluation Summary ===") print("\n--- Separation Metrics ---") print(f"Mean Distance (First Set): {results['separation']['mean_distance_first']:.4f}") print(f"Mean Distance (Second Set): {results['separation']['mean_distance_second']:.4f}") # print(f"Mean Cross-Set Distance: {results['separation']['mean_cross_distance']:.4f}") print("\n--- Purity Metrics ---") print(f"Overall Purity (First Set): {results['purity']['overall_purity_first']:.4f}") print(f"Overall Purity (Second Set): {results['purity']['overall_purity_second']:.4f}") print("\n--- Weight Diversity Metrics ---") print(f"Mean Normalized Entropy (First Set): {results['diversity']['mean_normalized_entropy_first']:.4f}") print(f"Mean Normalized Entropy (Second Set): {results['diversity']['mean_normalized_entropy_second']:.4f}")