# Source code for archetypax.tools.visualization

"""Advanced visualization tools for extracting insights from archetypal models.

This module provides specialized visualization capabilities that transform abstract
archetypal representations into intuitive visual insights. These visualizations
bridge the gap between mathematical models and human understanding by:

1. Revealing geometric relationships between data points and archetypes
2. Exposing patterns in feature utilization across different archetypes
3. Demonstrating reconstruction quality and model performance
4. Enabling exploration of relationships in both standard and biarchetypal space

These capabilities are essential for model interpretation, result communication,
and extracting actionable insights from archetypal analysis.
"""

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from ..models.base import ArchetypalAnalysis
from ..models.biarchetypes import BiarchetypalAnalysis


class ArchetypalAnalysisVisualizer:
    """Comprehensive visualization suite for archetypal analysis insights.

    This class provides specialized visualization methods that transform abstract
    archetypal models into intuitive visual representations. Rather than just
    plotting data, these methods reveal the underlying structures and relationships
    discovered by archetypal analysis, enabling:

    - Interpretation of archetype meaning and significance
    - Assessment of model quality and reconstruction fidelity
    - Communication of results to technical and non-technical audiences
    - Discovery of patterns in high-dimensional archetypal space

    Every public method is a ``staticmethod`` that takes the fitted model as its
    first argument, draws with ``matplotlib.pyplot`` and finishes with
    ``plt.show()``; nothing is returned.
    """

    @staticmethod
    def _simplex_vertices() -> np.ndarray:
        """Return the 2D corners of the unit equilateral triangle, one row per archetype."""
        return np.array([
            [0.0, 0.0],  # Archetype 0 at origin
            [1.0, 0.0],  # Archetype 1 at (1, 0)
            [0.5, np.sqrt(3) / 2],  # Archetype 2 at (0.5, sqrt(3)/2)
        ])

    @staticmethod
    def _simplex_grid_segments(vertices: np.ndarray, n_divisions: int = 10) -> np.ndarray:
        """Return the end points of the ternary grid lines inside the simplex.

        Each grid line is the set of points where one barycentric weight is held
        at ``step / n_divisions``; it is therefore parallel to the triangle edge
        opposite that archetype and both of its end points lie on the triangle.

        Args:
            vertices: ``(3, 2)`` array with the 2D position of each archetype
            n_divisions: Number of equal weight steps between 0 and 1

        Returns:
            Array of shape ``(3 * (n_divisions - 1), 2, 2)``: segment, end point, (x, y)
        """
        segments = []
        for step in range(1, n_divisions):
            level = step / n_divisions
            for fixed in range(3):
                first_free, second_free = (k for k in range(3) if k != fixed)
                start = np.zeros(3)
                end = np.zeros(3)
                start[fixed] = end[fixed] = level
                # The remaining mass goes entirely to one free archetype at each end
                start[first_free] = 1.0 - level
                end[second_free] = 1.0 - level
                segments.append(np.stack([start @ vertices, end @ vertices]))
        return np.array(segments).reshape(-1, 2, 2)

    @staticmethod
    def plot_loss(model: ArchetypalAnalysis) -> None:
        """Plot the loss value recorded at each optimisation iteration.

        Prints a message and returns without plotting when the model reports
        no loss history.

        Args:
            model: Fitted ArchetypalAnalysis model with loss history
        """
        loss_history = model.get_loss_history()
        # Use len() instead of truthiness: `not array` is ambiguous for array-valued histories
        if loss_history is None or len(loss_history) == 0:
            print("No loss history to plot")
            return

        plt.figure(figsize=(10, 6))
        plt.plot(loss_history)
        plt.xlabel("Iteration")
        plt.ylabel("Loss")
        plt.title("Archetypal Analysis Loss History")
        plt.grid(True)
        plt.show()

    @staticmethod
    def plot_archetypes_2d(model: ArchetypalAnalysis, X: np.ndarray, feature_names: list[str] | None = None) -> None:
        """Scatter 2D data together with the archetypes and their convex hull.

        For at most the first 100 samples, a faint arrow is drawn from the sample
        to its dominant archetype when that archetype's weight exceeds 0.5.

        Args:
            model: Fitted ArchetypalAnalysis model with discovered archetypes
            X: Original data matrix in 2D space
            feature_names: Optional feature names for meaningful axis labels

        Raises:
            ValueError: If the model is not fitted or ``X`` does not have two columns
        """
        if model.archetypes is None:
            raise ValueError("Model must be fitted before plotting")
        if model.weights is None:
            raise ValueError("Model must be fitted before plotting")
        if X.shape[1] != 2:
            raise ValueError("This plotting function is only for 2D data")

        archetypes = np.asarray(model.archetypes)
        weights = np.asarray(model.weights)

        plt.figure(figsize=(10, 8))
        plt.scatter(X[:, 0], X[:, 1], alpha=0.5, label="Data")
        plt.scatter(
            archetypes[:, 0],
            archetypes[:, 1],
            c="red",
            s=100,
            marker="*",
            label="Archetypes",
        )

        # Show max 100 arrows for performance; also never index past the weight rows
        n_arrows = min(100, len(X), len(weights))
        dominant = np.argmax(weights[:n_arrows], axis=1)
        for i, max_idx in enumerate(dominant):
            if weights[i, max_idx] > 0.5:  # Only draw if weight is significant
                plt.arrow(
                    X[i, 0],
                    X[i, 1],
                    archetypes[max_idx, 0] - X[i, 0],
                    archetypes[max_idx, 1] - X[i, 1],
                    head_width=0.01,
                    head_length=0.02,
                    alpha=0.1,
                    color="grey",
                )

        # Show convex hull
        if len(archetypes) >= 3:
            from scipy.spatial import ConvexHull

            # Best effort: degenerate (e.g. collinear) archetypes make the hull fail,
            # which should not prevent the rest of the figure from being shown.
            try:
                hull = ConvexHull(archetypes)
                for simplex in hull.simplices:
                    plt.plot(archetypes[simplex, 0], archetypes[simplex, 1], "r-")
            except Exception as e:
                print(f"Could not plot convex hull: {e!s}")

        # Add feature names if provided
        if feature_names is not None and len(feature_names) >= 2:
            plt.xlabel(feature_names[0])
            plt.ylabel(feature_names[1])
        else:
            plt.xlabel("Feature 1")
            plt.ylabel("Feature 2")

        plt.legend()
        plt.title("Data and Archetypes")
        plt.grid(True)
        plt.show()

    @staticmethod
    def plot_reconstruction_comparison(model: ArchetypalAnalysis, X: np.ndarray) -> None:
        """Show original 2D data next to the model's reconstruction.

        The Frobenius norm of ``X - model.reconstruct()`` is printed before plotting.

        Args:
            model: Fitted ArchetypalAnalysis model for reconstruction
            X: Original data matrix to be reconstructed

        Raises:
            ValueError: If the model is not fitted, ``X`` does not have two columns,
                or the reconstruction does not have the same shape as ``X``
        """
        if model.archetypes is None:
            raise ValueError("Model must be fitted before plotting")
        if X.shape[1] != 2:
            raise ValueError("This plotting function is only for 2D data")

        archetypes = np.asarray(model.archetypes)

        # Reconstruct data
        X_reconstructed = np.asarray(model.reconstruct())
        if X_reconstructed.shape != X.shape:
            # Fail with a clear message instead of an opaque broadcasting error
            # (or a silently broadcast, meaningless error value).
            raise ValueError(
                f"Reconstruction shape {X_reconstructed.shape} does not match data shape {X.shape}"
            )

        # Calculate reconstruction error
        error = np.linalg.norm(X - X_reconstructed, ord="fro")
        print(f"Reconstruction error: {error:.6f}")

        # Plot reconstruction
        plt.figure(figsize=(12, 6))

        plt.subplot(1, 2, 1)
        plt.scatter(X[:, 0], X[:, 1], alpha=0.7, label="Original")
        plt.title("Original Data")
        plt.legend()  # the scatter labels were previously never displayed
        plt.grid(True)

        plt.subplot(1, 2, 2)
        plt.scatter(
            X_reconstructed[:, 0],
            X_reconstructed[:, 1],
            alpha=0.7,
            label="Reconstructed",
        )
        plt.scatter(
            archetypes[:, 0],
            archetypes[:, 1],
            c="red",
            s=100,
            marker="*",
            label="Archetypes",
        )
        plt.title("Reconstructed Data")
        plt.legend()
        plt.grid(True)

        plt.tight_layout()
        plt.show()

    @staticmethod
    def plot_membership_weights(model: ArchetypalAnalysis, n_samples: int | None = None) -> None:
        """Draw a heatmap of sample-to-archetype membership weights.

        Rows are ordered by dominant archetype. When ``n_samples`` is smaller than
        the number of samples, rows are picked at evenly spaced positions of that
        ordering so that every dominant archetype stays represented (taking the
        first rows of the ordering would show almost only the first archetype).

        Args:
            model: Fitted ArchetypalAnalysis model with weights
            n_samples: Optional number of samples to visualize (default: all)

        Raises:
            ValueError: If the model is not fitted or ``n_samples`` is less than 1
        """
        if model.archetypes is None or model.weights is None:
            raise ValueError("Model must be fitted before plotting membership weights")
        if n_samples is not None and n_samples < 1:
            raise ValueError("n_samples must be a positive integer")

        weights = np.asarray(model.weights)
        total = weights.shape[0]
        n_shown = total if n_samples is None else min(n_samples, total)

        # Stable sort keeps the original sample order within each dominant archetype
        order = np.argsort(np.argmax(weights, axis=1), kind="stable")
        if n_shown < total:
            # floor(i * total / n_shown) is strictly increasing, so no row is repeated
            order = order[(np.arange(n_shown) * total) // n_shown]
        weights_subset = weights[order]

        plt.figure(figsize=(12, 8))

        # Create a heatmap of the membership weights
        ax = sns.heatmap(
            weights_subset,
            cmap="viridis",
            annot=True,
            vmin=0,
            vmax=1,
            yticklabels=False,
        )
        ax.set_xlabel("Archetypes")
        ax.set_ylabel("Samples")
        ax.set_title(f"Membership Weights for {n_shown} Samples")

        # Add archetype indices as x-tick labels (heatmap cells are centred at +0.5)
        plt.xticks(
            np.arange(model.n_archetypes) + 0.5,
            labels=[f"A{i}" for i in range(model.n_archetypes)],
        )

        plt.tight_layout()
        plt.show()

    @staticmethod
    def plot_archetype_profiles(model: ArchetypalAnalysis, feature_names: list[str] | None = None) -> None:
        """Plot the feature values of each archetype as one line per archetype.

        Args:
            model: Fitted ArchetypalAnalysis model
            feature_names: Optional list of feature names for axis labels;
                defaults to ``"Feature 0"``, ``"Feature 1"``, ...

        Raises:
            ValueError: If the model is not fitted
        """
        if model.archetypes is None:
            raise ValueError("Model must be fitted before plotting archetype profiles")

        archetypes = np.asarray(model.archetypes)
        n_features = archetypes.shape[1]

        # Create default feature names if not provided
        if feature_names is None:
            feature_names = [f"Feature {i}" for i in range(n_features)]

        # Feature indices for the x-axis
        x = np.arange(n_features)

        plt.figure(figsize=(12, 8))

        # Plot each archetype as a line
        for i, profile in enumerate(archetypes):
            plt.plot(x, profile, marker="o", label=f"Archetype {i}")

        plt.xlabel("Features")
        plt.ylabel("Feature Value")
        plt.title("Archetype Feature Profiles")
        plt.grid(True, alpha=0.3)
        plt.legend()

        # Set feature names as x-tick labels if not too many
        if n_features <= 20:
            plt.xticks(x, feature_names, rotation=45, ha="right")

        plt.tight_layout()
        plt.show()

    @staticmethod
    def plot_archetype_distribution(model: ArchetypalAnalysis) -> None:
        """Plot how many samples have each archetype as their dominant one.

        Each bar is annotated with its count and its share of all samples.

        Args:
            model: Fitted ArchetypalAnalysis model

        Raises:
            ValueError: If the model is not fitted
        """
        if model.weights is None:
            raise ValueError("Model must be fitted before plotting archetype distribution")

        n_archetypes = model.n_archetypes

        # Find the dominant archetype for each sample and count occurrences;
        # bincount yields zero for archetypes that are never dominant.
        dominant_archetypes = np.argmax(np.asarray(model.weights), axis=1)
        counts = np.bincount(dominant_archetypes, minlength=n_archetypes)[:n_archetypes]
        total_samples = len(dominant_archetypes)

        plt.figure(figsize=(10, 6))

        # Create a bar plot
        bars = plt.bar(range(n_archetypes), counts, color="skyblue", alpha=0.7)

        # Add labels and percentages
        for bar, count in zip(bars, counts):
            # Guard against an empty weight matrix
            percentage = 100 * count / total_samples if total_samples else 0.0
            plt.text(
                bar.get_x() + bar.get_width() / 2.0,
                count + 0.1,
                f"{int(count)} ({percentage:.1f}%)",
                ha="center",
                va="bottom",
                rotation=0,
            )

        plt.xlabel("Archetype")
        plt.ylabel("Number of Samples")
        plt.title("Distribution of Dominant Archetypes")
        plt.xticks(range(n_archetypes), [f"A{i}" for i in range(n_archetypes)])
        plt.grid(True, axis="y", alpha=0.3)

        plt.tight_layout()
        plt.show()

    @staticmethod
    def plot_simplex_2d(model: ArchetypalAnalysis, n_samples: int | None = 500) -> None:
        """Plot samples in 2D simplex space (only works for 3 archetypes).

        The weight vector of each sample is used as barycentric coordinates in an
        equilateral triangle whose corners are the three archetypes. Points are
        coloured by dominant archetype. When ``n_samples`` is smaller than the
        number of samples, a random subset (``np.random.choice`` without
        replacement) is shown.

        Args:
            model: Fitted ArchetypalAnalysis model
            n_samples: Number of samples to plot (default: 500); ``None`` plots all

        Raises:
            ValueError: If the model is not fitted or does not have exactly 3 archetypes
        """
        if model.archetypes is None or model.weights is None:
            raise ValueError("Model must be fitted before plotting simplex")

        if model.n_archetypes != 3:
            raise ValueError("Simplex plot only works for exactly 3 archetypes")

        # Select a subset of samples if specified
        weights = np.asarray(model.weights)
        if n_samples is not None and n_samples < weights.shape[0]:
            indices = np.random.choice(weights.shape[0], n_samples, replace=False)
            weights_subset = weights[indices]
        else:
            weights_subset = weights

        # Convert barycentric coordinates to 2D: each vertex represents an archetype
        triangle_vertices = ArchetypalAnalysisVisualizer._simplex_vertices()
        points_2d = np.dot(weights_subset, triangle_vertices)

        # Colour by which archetype has the highest weight
        dominant_archetypes = np.argmax(weights_subset, axis=1)

        plt.figure(figsize=(10, 8))

        # Plot the simplex boundaries (closed polygon through the three vertices)
        outline = np.vstack([triangle_vertices, triangle_vertices[:1]])
        plt.plot(outline[:, 0], outline[:, 1], "k-")

        # Add vertex labels
        top_y = triangle_vertices[2, 1]
        plt.text(-0.05, -0.05, "Archetype 0", ha="right")
        plt.text(1.05, -0.05, "Archetype 1", ha="left")
        plt.text(0.5, top_y + 0.05, "Archetype 2", ha="center")

        # Plot points colored by dominant archetype
        scatter = plt.scatter(
            points_2d[:, 0],
            points_2d[:, 1],
            c=dominant_archetypes,
            alpha=0.6,
            cmap="viridis",
        )

        # Add a color legend
        legend1 = plt.legend(*scatter.legend_elements(), title="Dominant Archetype")
        plt.gca().add_artist(legend1)

        # Add grid lines for the simplex: iso-weight lines at 0.1 steps, each
        # parallel to the edge opposite the archetype whose weight is held fixed
        for segment in ArchetypalAnalysisVisualizer._simplex_grid_segments(triangle_vertices, 10):
            plt.plot(segment[:, 0], segment[:, 1], "gray", alpha=0.3)

        plt.axis("equal")
        plt.title("Samples in Simplex Space")
        plt.axis("off")
        plt.tight_layout()
        plt.show()
class BiarchetypalAnalysisVisualizer:
    """Visualization utilities for Biarchetypal Analysis.

    Every public method is a ``staticmethod`` that takes the fitted model as its
    first argument, draws with ``matplotlib.pyplot`` and finishes with
    ``plt.show()``; nothing is returned.

    NOTE(review): the second-set weights returned by ``model.get_all_weights()``
    are accepted either with one row per second-set archetype (first dimension
    equal to ``model.n_col_archetypes``, transposed before use) or already with
    one row per item. Which layout the model actually produces, and whether its
    rows correspond to samples, should be confirmed against BiarchetypalAnalysis.
    """

    @staticmethod
    def _orient_second_weights(weights_second, n_col_archetypes: int) -> np.ndarray | None:
        """Return the second-set weights with one column per second-set archetype.

        A 2D matrix whose first dimension equals ``n_col_archetypes`` is
        transposed; any other 2D matrix is returned as is.

        Returns:
            The oriented 2D array, or ``None`` when the input is not 2D.
        """
        weights_second = np.asarray(weights_second)
        if weights_second.ndim != 2:
            return None
        if weights_second.shape[0] == n_col_archetypes:
            return weights_second.T
        return weights_second

    @staticmethod
    def _reconstruct_both_sets(model) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Compute ``weights @ archetypes`` for both archetype sets.

        When the second reconstruction cannot be computed in the shape of the
        first one, zeros (optionally filled with the overlapping part of what
        could be computed) are used instead and a warning is printed.

        Returns:
            ``(archetypes_first, archetypes_second, X_recon_first, X_recon_second)``
        """
        weights_first, weights_second = model.get_all_weights()
        archetypes_first, archetypes_second = model.get_all_archetypes()
        weights_first = np.asarray(weights_first)
        weights_second = np.asarray(weights_second)
        archetypes_first = np.asarray(archetypes_first)
        archetypes_second = np.asarray(archetypes_second)

        X_recon_first = np.matmul(weights_first, archetypes_first)

        if weights_second.ndim == 2 and weights_second.shape[0] == model.n_col_archetypes:
            # Rows index second-set archetypes: transpose to one row per item
            weights_second_by_row = weights_second.T
            if weights_second_by_row.shape[0] == weights_first.shape[0]:
                X_recon_second = np.matmul(weights_second_by_row, archetypes_second)
            else:
                X_recon_second = np.zeros_like(X_recon_first)
                print("Warning: Shape mismatch in weights_second. Using dummy data for second reconstruction.")
        else:
            # For standard shape
            X_recon_second = np.matmul(weights_second, archetypes_second)

        if X_recon_second.shape != X_recon_first.shape:
            # Zero-pad / crop to the first reconstruction's shape, keeping the overlap
            raw = X_recon_second
            X_recon_second = np.zeros_like(X_recon_first)
            rows = min(raw.shape[0], X_recon_first.shape[0])
            cols = min(raw.shape[1], X_recon_first.shape[1])
            X_recon_second[:rows, :cols] = raw[:rows, :cols]
            print(
                f"Warning: Shape mismatch between reconstructions. "
                f"X_recon_first: {X_recon_first.shape}, X_recon_second: {raw.shape}"
            )

        return archetypes_first, archetypes_second, X_recon_first, X_recon_second

    @staticmethod
    def _draw_dominance_arrow(origin, target, color: str, linestyle: str) -> None:
        """Draw a faint arrow from a data point to an archetype on the current axes."""
        plt.arrow(
            origin[0],
            origin[1],
            target[0] - origin[0],
            target[1] - origin[1],
            head_width=0.01,
            head_length=0.02,
            alpha=0.2,
            color=color,
            linestyle=linestyle,
        )

    @staticmethod
    def _simplex_vertices() -> np.ndarray:
        """Return the 2D corners of the unit equilateral triangle, one row per archetype."""
        return np.array([
            [0.0, 0.0],  # Archetype 0 at origin
            [1.0, 0.0],  # Archetype 1 at (1, 0)
            [0.5, np.sqrt(3) / 2],  # Archetype 2 at (0.5, sqrt(3)/2)
        ])

    @staticmethod
    def _simplex_grid_segments(vertices: np.ndarray, n_divisions: int = 10) -> np.ndarray:
        """Return the end points of the ternary grid lines inside the simplex.

        Each grid line is the set of points where one barycentric weight is held
        at ``step / n_divisions``; it is therefore parallel to the triangle edge
        opposite that archetype and both of its end points lie on the triangle.

        Returns:
            Array of shape ``(3 * (n_divisions - 1), 2, 2)``: segment, end point, (x, y)
        """
        segments = []
        for step in range(1, n_divisions):
            level = step / n_divisions
            for fixed in range(3):
                first_free, second_free = (k for k in range(3) if k != fixed)
                start = np.zeros(3)
                end = np.zeros(3)
                start[fixed] = end[fixed] = level
                # The remaining mass goes entirely to one free archetype at each end
                start[first_free] = 1.0 - level
                end[second_free] = 1.0 - level
                segments.append(np.stack([start @ vertices, end @ vertices]))
        return np.array(segments).reshape(-1, 2, 2)

    @staticmethod
    def plot_dual_archetypes_2d(
        model: BiarchetypalAnalysis, X: np.ndarray, feature_names: list[str] | None = None
    ) -> None:
        """Plot data and both sets of archetypes in 2D.

        For at most the first 50 samples, an arrow is drawn to the dominant
        archetype of each set when that archetype's weight exceeds 0.5.

        Args:
            model: Fitted BiarchetypalAnalysis model
            X: Original data
            feature_names: Optional feature names for axis labels

        Raises:
            ValueError: If ``X`` does not have two columns
        """
        if X.shape[1] != 2:
            raise ValueError("This plotting function is only for 2D data")

        archetypes_first, archetypes_second = model.get_all_archetypes()
        weights_first, weights_second = model.get_all_weights()
        weights_first = np.asarray(weights_first)
        second_by_row = BiarchetypalAnalysisVisualizer._orient_second_weights(
            weights_second, model.n_col_archetypes
        )

        plt.figure(figsize=(12, 8))

        # Plot data points
        plt.scatter(X[:, 0], X[:, 1], alpha=0.4, color="gray", label="Data")

        # Plot first set of archetypes
        plt.scatter(
            archetypes_first[:, 0],
            archetypes_first[:, 1],
            c="blue",
            s=150,
            marker="*",
            label=f"Archetypes Set 1 (n={model.n_row_archetypes})",
        )

        # Plot second set of archetypes
        plt.scatter(
            archetypes_second[:, 0],
            archetypes_second[:, 1],
            c="red",
            s=150,
            marker="^",
            label=f"Archetypes Set 2 (n={model.n_col_archetypes})",
        )

        # Connect points to their dominant archetypes in each set
        for i in range(min(50, X.shape[0])):  # Limit to 50 arrows for visual clarity
            point = X[i]

            # Dominant archetype in first set (skip rows the weight matrix does not have)
            if i < weights_first.shape[0]:
                max_idx_first = np.argmax(weights_first[i])
                if weights_first[i, max_idx_first] > 0.5:
                    BiarchetypalAnalysisVisualizer._draw_dominance_arrow(
                        point, archetypes_first[max_idx_first], "blue", "--"
                    )

            # Dominant archetype in second set
            if second_by_row is not None and i < second_by_row.shape[0]:
                max_idx_second = np.argmax(second_by_row[i])
                if second_by_row[i, max_idx_second] > 0.5:
                    BiarchetypalAnalysisVisualizer._draw_dominance_arrow(
                        point, archetypes_second[max_idx_second], "red", ":"
                    )

        # Add feature names if provided
        if feature_names and len(feature_names) >= 2:
            plt.xlabel(feature_names[0])
            plt.ylabel(feature_names[1])
        else:
            plt.xlabel("Feature 1")
            plt.ylabel("Feature 2")

        plt.title("Data and Dual Archetype Sets")
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.show()

    @staticmethod
    def plot_biarchetypal_reconstruction(model: BiarchetypalAnalysis, X: np.ndarray) -> None:
        """Plot original data vs. reconstructions from each archetype set and combined.

        The combined reconstruction is
        ``w * first + (1 - w) * second`` with ``w = model.mixture_weight``
        (0.5 when the model has no such attribute). Frobenius-norm errors against
        ``X`` are shown in the subplot titles. A warning is printed when the
        second reconstruction had to be replaced by placeholder data.

        Args:
            model: Fitted BiarchetypalAnalysis model
            X: Original data matrix

        Raises:
            ValueError: If ``X`` does not have two columns
        """
        if X.shape[1] != 2:
            raise ValueError("This plotting function is only for 2D data")

        (
            archetypes_first,
            archetypes_second,
            X_recon_first,
            X_recon_second,
        ) = BiarchetypalAnalysisVisualizer._reconstruct_both_sets(model)

        # Use equal weights if mixture_weight doesn't exist
        mixture_weight = getattr(model, "mixture_weight", 0.5)
        X_recon_combined = mixture_weight * X_recon_first + (1 - mixture_weight) * X_recon_second

        # Calculate reconstruction errors
        error_first = np.linalg.norm(X - X_recon_first, ord="fro")
        error_second = np.linalg.norm(X - X_recon_second, ord="fro")
        error_combined = np.linalg.norm(X - X_recon_combined, ord="fro")

        # Create plot with subplots
        _, axes = plt.subplots(2, 2, figsize=(15, 12))

        # Original data
        axes[0, 0].scatter(X[:, 0], X[:, 1], alpha=0.7, label="Original")
        axes[0, 0].set_title("Original Data")
        axes[0, 0].grid(True, alpha=0.3)

        # First archetype set reconstruction
        axes[0, 1].scatter(X_recon_first[:, 0], X_recon_first[:, 1], alpha=0.7, color="blue")
        axes[0, 1].scatter(
            archetypes_first[:, 0],
            archetypes_first[:, 1],
            c="blue",
            s=100,
            marker="*",
            label="Archetypes Set 1",
        )
        axes[0, 1].set_title(f"First Set Reconstruction\nError: {error_first:.4f}")
        axes[0, 1].grid(True, alpha=0.3)

        # Second archetype set reconstruction
        axes[1, 0].scatter(X_recon_second[:, 0], X_recon_second[:, 1], alpha=0.7, color="red")
        axes[1, 0].scatter(
            archetypes_second[:, 0],
            archetypes_second[:, 1],
            c="red",
            s=100,
            marker="^",
            label="Archetypes Set 2",
        )
        axes[1, 0].set_title(f"Second Set Reconstruction\nError: {error_second:.4f}")
        axes[1, 0].grid(True, alpha=0.3)

        # Combined reconstruction
        axes[1, 1].scatter(X_recon_combined[:, 0], X_recon_combined[:, 1], alpha=0.7, color="purple")
        axes[1, 1].scatter(
            archetypes_first[:, 0],
            archetypes_first[:, 1],
            c="blue",
            s=100,
            marker="*",
            label="Archetypes Set 1",
        )
        axes[1, 1].scatter(
            archetypes_second[:, 0],
            archetypes_second[:, 1],
            c="red",
            s=100,
            marker="^",
            label="Archetypes Set 2",
        )
        axes[1, 1].set_title(f"Combined Reconstruction (w={mixture_weight:.2f})\nError: {error_combined:.4f}")
        axes[1, 1].grid(True, alpha=0.3)

        # Add legend to the last plot
        axes[1, 1].legend()

        plt.tight_layout()
        plt.show()

    @staticmethod
    def plot_dual_membership_heatmap(model: BiarchetypalAnalysis, n_samples: int = 50) -> None:
        """Plot heatmap of membership weights for both sets of archetypes.

        Rows are ordered by dominant first-set archetype. When fewer rows than
        available are requested, they are picked at evenly spaced positions of
        that ordering so every dominant archetype stays represented. If the
        second-set weights cannot be laid out as (rows of first set,
        ``n_col_archetypes``), a matrix of ones is shown instead and a warning
        is printed.

        Args:
            model: Fitted BiarchetypalAnalysis model
            n_samples: Number of samples to visualize

        Raises:
            ValueError: If ``n_samples`` is less than 1
        """
        if n_samples < 1:
            raise ValueError("n_samples must be a positive integer")

        weights_first, weights_second = model.get_all_weights()
        weights_first = np.asarray(weights_first)
        total = weights_first.shape[0]
        n_shown = min(n_samples, total)

        # Sort samples by their dominant archetype in first set only
        sorted_indices = np.argsort(np.argmax(weights_first, axis=1), kind="stable")
        if n_shown < total:
            # floor(i * total / n_shown) is strictly increasing, so no row is repeated
            sorted_indices = sorted_indices[(np.arange(n_shown) * total) // n_shown]

        weights_first_subset = weights_first[sorted_indices]

        second_by_row = BiarchetypalAnalysisVisualizer._orient_second_weights(
            weights_second, model.n_col_archetypes
        )
        if second_by_row is not None and second_by_row.shape == (total, model.n_col_archetypes):
            weights_second_subset = second_by_row[sorted_indices]
        else:
            # Placeholder must have one column per tick label below
            weights_second_subset = np.ones((len(sorted_indices), model.n_col_archetypes))
            print("Warning: Shape mismatch in weights_second. Using placeholder values for second heatmap.")

        # Create figure with two subplots
        _, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 8))

        # Plot first set heatmap
        sns.heatmap(
            weights_first_subset,
            ax=ax1,
            cmap="viridis",
            annot=True,
            fmt=".2f",
            xticklabels=[f"A1_{i}" for i in range(model.n_row_archetypes)],
            yticklabels=False,
        )
        ax1.set_title(f"First Set Membership Weights (n={model.n_row_archetypes})")
        ax1.set_xlabel("Archetypes (First Set)")
        ax1.set_ylabel("Samples")

        # Plot second set heatmap
        sns.heatmap(
            weights_second_subset,
            ax=ax2,
            cmap="viridis",
            annot=True,
            fmt=".2f",
            xticklabels=[f"A2_{i}" for i in range(model.n_col_archetypes)],
            yticklabels=False,
        )
        ax2.set_title(f"Second Set Membership Weights (n={model.n_col_archetypes})")
        ax2.set_xlabel("Archetypes (Second Set)")

        plt.tight_layout()
        plt.show()

    @staticmethod
    def plot_mixture_effect(model: BiarchetypalAnalysis, X: np.ndarray, mixture_steps: int = 5) -> None:
        """Plot the effect of different mixture weights between the two archetype sets.

        ``mixture_steps`` evenly spaced weights between 0 and 1 are tried (a
        single step uses 0.5); each subplot shows
        ``w * first + (1 - w) * second`` over the original data together with
        its Frobenius-norm error.

        Args:
            model: Fitted BiarchetypalAnalysis model
            X: Original data matrix
            mixture_steps: Number of different mixture weights to try

        Raises:
            ValueError: If ``X`` does not have two columns or ``mixture_steps`` is less than 1
        """
        if X.shape[1] != 2:
            raise ValueError("This plotting function is only for 2D data")
        if mixture_steps < 1:
            raise ValueError("mixture_steps must be at least 1")

        (
            archetypes_first,
            archetypes_second,
            X_recon_first,
            X_recon_second,
        ) = BiarchetypalAnalysisVisualizer._reconstruct_both_sets(model)

        # Up to three plots per row
        n_cols = min(3, mixture_steps)
        n_rows = (mixture_steps + 2) // 3  # Ceiling division
        # squeeze=False always yields a 2D array, so one code path handles 1, 2, 3 or more steps
        _, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows), squeeze=False)
        flat_axes = axes.ravel()

        # Original data for reference
        X_original = X.copy()

        # Try different mixture weights
        for i, ax in enumerate(flat_axes[:mixture_steps]):
            mix_weight = i / (mixture_steps - 1) if mixture_steps > 1 else 0.5

            # Create mixed reconstruction and its error
            X_mixed = mix_weight * X_recon_first + (1 - mix_weight) * X_recon_second
            error = np.linalg.norm(X_original - X_mixed, ord="fro")

            ax.scatter(X_original[:, 0], X_original[:, 1], alpha=0.2, color="gray", label="Original")
            ax.scatter(X_mixed[:, 0], X_mixed[:, 1], alpha=0.7, color="purple", label="Reconstructed")
            ax.scatter(
                archetypes_first[:, 0],
                archetypes_first[:, 1],
                c="blue",
                s=80,
                marker="*",
                label="Set 1",
            )
            ax.scatter(
                archetypes_second[:, 0],
                archetypes_second[:, 1],
                c="red",
                s=80,
                marker="^",
                label="Set 2",
            )

            ax.set_title(f"w={mix_weight:.2f}, Error={error:.4f}")

            # Only add legend to the first plot
            if i == 0:
                ax.legend()

            ax.grid(True, alpha=0.3)

        # Hide grid cells that have no mixture weight assigned (e.g. the 6th cell for 5 steps)
        for ax in flat_axes[mixture_steps:]:
            ax.set_visible(False)

        plt.tight_layout()
        plt.show()

    @staticmethod
    def plot_dual_simplex_2d(model: BiarchetypalAnalysis, n_samples: int = 200) -> None:
        """Plot samples in separate 2D simplex spaces for each archetype set.

        A set is drawn from its real weights only when it has exactly 3
        archetypes (and, for the second set, when its weights have one row per
        first-set row). Otherwise uniform weights (all points at the triangle
        centre) are drawn and the subplot title is marked ``(Dummy)``. When
        ``n_samples`` is smaller than the number of rows, a random subset
        (``np.random.choice`` without replacement) is shown.

        Args:
            model: Fitted BiarchetypalAnalysis model
            n_samples: Number of samples to plot

        Raises:
            ValueError: If neither archetype set has exactly 3 archetypes
        """
        # Relaxed condition: at least one of the archetype sets must have exactly 3 archetypes
        if model.n_row_archetypes != 3 and model.n_col_archetypes != 3:
            raise ValueError("This simplex plot requires at least one set to have exactly 3 archetypes")

        # Get weights for both sets
        weights_first, weights_second = model.get_all_weights()
        weights_first = np.asarray(weights_first)
        total = weights_first.shape[0]
        second_by_row = BiarchetypalAnalysisVisualizer._orient_second_weights(
            weights_second, model.n_col_archetypes
        )

        # Select a subset of samples if needed
        if n_samples < total:
            indices = np.random.choice(total, n_samples, replace=False)
        else:
            indices = np.arange(total)

        # Placeholder used for any set that cannot be shown in a 3-archetype simplex
        uniform_weights = np.ones((len(indices), 3)) / 3

        first_is_dummy = model.n_row_archetypes != 3 or weights_first.shape[1] != 3
        weights_first_subset = uniform_weights if first_is_dummy else weights_first[indices]

        second_is_dummy = (
            second_by_row is None
            or second_by_row.shape[0] != total
            or model.n_col_archetypes != 3
            or second_by_row.shape[1] != 3
        )
        weights_second_subset = uniform_weights if second_is_dummy else second_by_row[indices]

        # Set up the figure with two subplots
        _, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7))

        # Convert barycentric coordinates to 2D for visualization
        triangle_vertices = BiarchetypalAnalysisVisualizer._simplex_vertices()
        outline = np.vstack([triangle_vertices, triangle_vertices[:1]])
        top_y = triangle_vertices[2, 1]
        grid_segments = BiarchetypalAnalysisVisualizer._simplex_grid_segments(triangle_vertices, 10)

        panels = [
            (ax1, weights_first_subset, "Blues", "A1", "blue", "First Archetype Set Simplex", first_is_dummy),
            (ax2, weights_second_subset, "Reds", "A2", "red", "Second Archetype Set Simplex", second_is_dummy),
        ]
        for ax, weights_subset, cmap, prefix, color, title, is_dummy in panels:
            points_2d = np.dot(weights_subset, triangle_vertices)
            dominant_archetypes = np.argmax(weights_subset, axis=1)

            ax.plot(outline[:, 0], outline[:, 1], "k-")
            ax.scatter(
                points_2d[:, 0],
                points_2d[:, 1],
                c=dominant_archetypes,
                alpha=0.6,
                cmap=cmap,
            )
            ax.text(-0.05, -0.05, f"{prefix}_0", ha="right", color=color)
            ax.text(1.05, -0.05, f"{prefix}_1", ha="left", color=color)
            ax.text(0.5, top_y + 0.05, f"{prefix}_2", ha="center", color=color)
            ax.set_title(title + (" (Dummy)" if is_dummy else ""))
            ax.axis("equal")
            ax.axis("off")

            # Grid: iso-weight lines at 0.1 steps, parallel to the opposite edge
            for segment in grid_segments:
                ax.plot(segment[:, 0], segment[:, 1], "gray", alpha=0.3)

        plt.tight_layout()
        plt.show()