"""Visualization utilities for Archetypal Analysis."""
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from ..models.base import ArchetypalAnalysis
from ..models.biarchetypes import BiarchetypalAnalysis
[docs]
class ArchetypalAnalysisVisualizer:
"""Visualization utilities for Archetypal Analysis."""
[docs]
@staticmethod
def plot_loss(model: ArchetypalAnalysis) -> None:
"""
Plot the loss history from training.
Args:
model: Fitted ArchetypalAnalysis model
"""
loss_history = model.get_loss_history()
if not loss_history:
print("No loss history to plot")
return
plt.figure(figsize=(10, 6))
plt.plot(loss_history)
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.title("Archetypal Analysis Loss History")
plt.grid(True)
plt.show()
[docs]
@staticmethod
def plot_archetypes_2d(
model: ArchetypalAnalysis, X: np.ndarray, feature_names: list[str] | None = None
) -> None:
"""
Plot data and archetypes in 2D.
Args:
model: Fitted ArchetypalAnalysis model
X: Original data
feature_names: Optional feature names for axis labels
"""
from scipy.spatial import ConvexHull
if model.archetypes is None:
raise ValueError("Model must be fitted before plotting")
if model.weights is None:
raise ValueError("Model must be fitted before plotting")
if X.shape[1] != 2:
raise ValueError("This plotting function is only for 2D data")
weights: np.ndarray = model.weights
plt.figure(figsize=(10, 8))
plt.scatter(X[:, 0], X[:, 1], alpha=0.5, label="Data")
plt.scatter(
model.archetypes[:, 0],
model.archetypes[:, 1],
c="red",
s=100,
marker="*",
label="Archetypes",
)
# Add arrows from data points to their dominant archetypes
for i in range(min(100, len(X))): # Show max 100 arrows for performance
# Find the archetype with the highest weight
if weights is not None and model.archetypes is not None:
max_idx = np.argmax(weights[i])
if weights[i, max_idx] > 0.5: # Only draw if weight is significant
plt.arrow(
X[i, 0],
X[i, 1],
model.archetypes[max_idx, 0] - X[i, 0],
model.archetypes[max_idx, 1] - X[i, 1],
head_width=0.01,
head_length=0.02,
alpha=0.1,
color="grey",
)
# Show convex hull
if len(model.archetypes) >= 3:
try:
hull = ConvexHull(model.archetypes)
for simplex in hull.simplices:
plt.plot(model.archetypes[simplex, 0], model.archetypes[simplex, 1], "r-")
except Exception as e:
print(f"Could not plot convex hull: {e!s}")
# Add feature names if provided
if feature_names is not None and len(feature_names) >= 2:
plt.xlabel(feature_names[0])
plt.ylabel(feature_names[1])
else:
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.legend()
plt.title("Data and Archetypes")
plt.grid(True)
plt.show()
[docs]
@staticmethod
def plot_reconstruction_comparison(model: ArchetypalAnalysis, X: np.ndarray) -> None:
"""
Plot original vs reconstructed data.
Args:
model: Fitted ArchetypalAnalysis model
X: Original data matrix
"""
if model.archetypes is None:
raise ValueError("Model must be fitted before plotting")
if X.shape[1] != 2:
raise ValueError("This plotting function is only for 2D data")
# Reconstruct data
X_reconstructed = model.reconstruct()
# Calculate reconstruction error
error = np.linalg.norm(X - X_reconstructed, ord="fro")
print(f"Reconstruction error: {error:.6f}")
# Plot reconstruction
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.scatter(X[:, 0], X[:, 1], alpha=0.7, label="Original")
plt.title("Original Data")
plt.grid(True)
plt.subplot(1, 2, 2)
plt.scatter(
X_reconstructed[:, 0],
X_reconstructed[:, 1],
alpha=0.7,
label="Reconstructed",
)
plt.scatter(
model.archetypes[:, 0],
model.archetypes[:, 1],
c="red",
s=100,
marker="*",
label="Archetypes",
)
plt.title("Reconstructed Data")
plt.grid(True)
plt.tight_layout()
plt.show()
[docs]
@staticmethod
def plot_membership_weights(model: ArchetypalAnalysis, n_samples: int | None = None) -> None:
"""
Plot membership weights for samples.
Args:
model: Fitted ArchetypalAnalysis model
n_samples: Optional number of samples to visualize (default: all)
"""
if model.archetypes is None or model.weights is None:
raise ValueError("Model must be fitted before plotting membership weights")
weights = model.weights
if n_samples is not None:
# Select a subset of samples if specified
n_samples = min(n_samples, weights.shape[0])
# Sort samples by their max weight for better visualization
max_weight_idx = np.argmax(weights, axis=1)
sorted_indices = np.argsort(max_weight_idx)
sample_indices = sorted_indices[:n_samples]
weights_subset = weights[sample_indices]
else:
# Use all samples, but sort them for better visualization
max_weight_idx = np.argmax(weights, axis=1)
sorted_indices = np.argsort(max_weight_idx)
weights_subset = weights[sorted_indices]
n_samples = weights.shape[0]
plt.figure(figsize=(12, 8))
# Create a heatmap of the membership weights
ax = sns.heatmap(
weights_subset,
cmap="viridis",
annot=True,
vmin=0,
vmax=1,
yticklabels=False,
)
ax.set_xlabel("Archetypes")
ax.set_ylabel("Samples")
ax.set_title(f"Membership Weights for {n_samples} Samples")
# Add archetype indices as x-tick labels
plt.xticks(
np.arange(model.n_archetypes) + 0.5,
labels=[f"A{i}" for i in range(model.n_archetypes)],
)
plt.tight_layout()
plt.show()
[docs]
@staticmethod
def plot_archetype_profiles(
model: ArchetypalAnalysis, feature_names: list[str] | None = None
) -> None:
"""
Plot feature profiles of each archetype.
Args:
model: Fitted ArchetypalAnalysis model
feature_names: Optional list of feature names for axis labels
"""
if model.archetypes is None:
raise ValueError("Model must be fitted before plotting archetype profiles")
n_archetypes, n_features = model.archetypes.shape
# Create default feature names if not provided
if feature_names is None:
feature_names = [f"Feature {i}" for i in range(n_features)]
# Prepare feature indices for the x-axis
x = np.arange(n_features)
plt.figure(figsize=(12, 8))
# Plot each archetype as a line
for i in range(n_archetypes):
plt.plot(x, model.archetypes[i], marker="o", label=f"Archetype {i}")
plt.xlabel("Features")
plt.ylabel("Feature Value")
plt.title("Archetype Feature Profiles")
plt.grid(True, alpha=0.3)
plt.legend()
# Set feature names as x-tick labels if not too many
if n_features <= 20:
plt.xticks(x, feature_names, rotation=45, ha="right")
plt.tight_layout()
plt.show()
[docs]
@staticmethod
def plot_archetype_distribution(model: ArchetypalAnalysis) -> None:
"""
Plot the distribution of dominant archetypes across samples.
Args:
model: Fitted ArchetypalAnalysis model
"""
if model.weights is None:
raise ValueError("Model must be fitted before plotting archetype distribution")
# Find the dominant archetype for each sample
dominant_archetypes = np.argmax(model.weights, axis=1)
# Count occurrences of each archetype as dominant
unique, counts = np.unique(dominant_archetypes, return_counts=True)
plt.figure(figsize=(10, 6))
# Create a bar plot
bars = plt.bar(
range(model.n_archetypes),
[
counts[list(unique).index(i)] if i in unique else 0
for i in range(model.n_archetypes)
],
color="skyblue",
alpha=0.7,
)
# Add labels and percentages
total_samples = len(dominant_archetypes)
for bar in bars:
height = bar.get_height()
percentage = 100 * height / total_samples
plt.text(
bar.get_x() + bar.get_width() / 2.0,
height + 0.1,
f"{height} ({percentage:.1f}%)",
ha="center",
va="bottom",
rotation=0,
)
plt.xlabel("Archetype")
plt.ylabel("Number of Samples")
plt.title("Distribution of Dominant Archetypes")
plt.xticks(range(model.n_archetypes), [f"A{i}" for i in range(model.n_archetypes)])
plt.grid(True, axis="y", alpha=0.3)
plt.tight_layout()
plt.show()
[docs]
@staticmethod
def plot_simplex_2d(model: ArchetypalAnalysis, n_samples: int | None = 500) -> None:
"""
Plot samples in 2D simplex space (only works for 3 archetypes).
Args:
model: Fitted ArchetypalAnalysis model
n_samples: Number of samples to plot (default: 500)
"""
if model.archetypes is None or model.weights is None:
raise ValueError("Model must be fitted before plotting simplex")
if model.n_archetypes != 3:
raise ValueError("Simplex plot only works for exactly 3 archetypes")
# Select a subset of samples if specified
weights = model.weights
if n_samples is not None and n_samples < weights.shape[0]:
indices = np.random.choice(weights.shape[0], n_samples, replace=False)
weights_subset = weights[indices]
else:
weights_subset = weights
# Convert barycentric coordinates to 2D for visualization
# For a 3-simplex, we can use an equilateral triangle
# Where each vertex represents an archetype
sqrt3_2 = np.sqrt(3) / 2
triangle_vertices = np.array(
[
[0, 0], # Archetype 0 at origin
[1, 0], # Archetype 1 at (1,0)
[0.5, sqrt3_2], # Archetype 2 at (0.5, sqrt(3)/2)
]
)
# Transform weights to 2D coordinates
points_2d = np.dot(weights_subset, triangle_vertices)
# Create a colormap based on which archetype has the highest weight
dominant_archetypes = np.argmax(weights_subset, axis=1)
plt.figure(figsize=(10, 8))
# Plot the simplex boundaries
plt.plot([0, 1, 0.5, 0], [0, 0, sqrt3_2, 0], "k-")
# Add vertex labels
plt.text(-0.05, -0.05, "Archetype 0", ha="right")
plt.text(1.05, -0.05, "Archetype 1", ha="left")
plt.text(0.5, sqrt3_2 + 0.05, "Archetype 2", ha="center")
# Plot points colored by dominant archetype
scatter = plt.scatter(
points_2d[:, 0],
points_2d[:, 1],
c=dominant_archetypes,
alpha=0.6,
cmap="viridis",
)
# Add a color legend
legend1 = plt.legend(*scatter.legend_elements(), title="Dominant Archetype")
plt.gca().add_artist(legend1)
# Add grid lines for the simplex
for i in range(1, 10):
p = i / 10
# Line parallel to the bottom edge
plt.plot(
[p * 0.5, p + (1 - p) * 0.5],
[p * sqrt3_2, (1 - p) * 0],
"gray",
alpha=0.3,
)
# Line parallel to the left edge
plt.plot([0, p * 0.5], [p * 0, p * sqrt3_2], "gray", alpha=0.3)
# Line parallel to the right edge
plt.plot(
[p * 1, 0.5 + (1 - p) * 0.5],
[p * 0, (1 - p) * sqrt3_2],
"gray",
alpha=0.3,
)
plt.axis("equal")
plt.title("Samples in Simplex Space")
plt.axis("off")
plt.tight_layout()
plt.show()
[docs]
class BiarchetypalAnalysisVisualizer:
"""Visualization utilities for Biarchetypal Analysis."""
[docs]
@staticmethod
def plot_dual_archetypes_2d(
model: BiarchetypalAnalysis, X: np.ndarray, feature_names: list[str] | None = None
) -> None:
"""
Plot data and both sets of archetypes in 2D.
Args:
model: Fitted BiarchetypalAnalysis model
X: Original data
feature_names: Optional feature names for axis labels
"""
if X.shape[1] != 2:
raise ValueError("This plotting function is only for 2D data")
archetypes_first, archetypes_second = model.get_all_archetypes()
weights_first, weights_second = model.get_all_weights()
plt.figure(figsize=(12, 8))
# Plot data points
plt.scatter(X[:, 0], X[:, 1], alpha=0.4, color="gray", label="Data")
# Plot first set of archetypes
plt.scatter(
archetypes_first[:, 0],
archetypes_first[:, 1],
c="blue",
s=150,
marker="*",
label=f"Archetypes Set 1 (n={model.n_row_archetypes})",
)
# Plot second set of archetypes
plt.scatter(
archetypes_second[:, 0],
archetypes_second[:, 1],
c="red",
s=150,
marker="^",
label=f"Archetypes Set 2 (n={model.n_col_archetypes})",
)
# Connect points to their dominant archetypes in each set
for i in range(min(50, X.shape[0])): # Limit to 50 arrows for visual clarity
# Find dominant archetype in first set
max_idx_first = np.argmax(weights_first[i])
if weights_first[i, max_idx_first] > 0.5:
plt.arrow(
X[i, 0],
X[i, 1],
archetypes_first[max_idx_first, 0] - X[i, 0],
archetypes_first[max_idx_first, 1] - X[i, 1],
head_width=0.01,
head_length=0.02,
alpha=0.2,
color="blue",
linestyle="--",
)
# Find dominant archetype in second set
# Handle different shapes of weights_second
if len(weights_second.shape) == 2 and weights_second.shape[0] == model.n_col_archetypes:
# If weights_second has shape (n_col_archetypes, n_features)
weights_second_transposed = weights_second.T
if i < weights_second_transposed.shape[0]:
max_idx_second = np.argmax(weights_second_transposed[i])
if weights_second_transposed[i, max_idx_second] > 0.5:
plt.arrow(
X[i, 0],
X[i, 1],
archetypes_second[max_idx_second, 0] - X[i, 0],
archetypes_second[max_idx_second, 1] - X[i, 1],
head_width=0.01,
head_length=0.02,
alpha=0.2,
color="red",
linestyle=":",
)
elif len(weights_second.shape) == 2 and i < weights_second.shape[0]:
# For standard shape
max_idx_second = np.argmax(weights_second[i])
if weights_second[i, max_idx_second] > 0.5:
plt.arrow(
X[i, 0],
X[i, 1],
archetypes_second[max_idx_second, 0] - X[i, 0],
archetypes_second[max_idx_second, 1] - X[i, 1],
head_width=0.01,
head_length=0.02,
alpha=0.2,
color="red",
linestyle=":",
)
# Add feature names if provided
if feature_names and len(feature_names) >= 2:
plt.xlabel(feature_names[0])
plt.ylabel(feature_names[1])
else:
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.title("Data and Dual Archetype Sets")
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
[docs]
@staticmethod
def plot_biarchetypal_reconstruction(model: BiarchetypalAnalysis, X: np.ndarray) -> None:
"""
Plot original data vs. reconstructions from each archetype set and combined.
Args:
model: Fitted BiarchetypalAnalysis model
X: Original data matrix
"""
if X.shape[1] != 2:
raise ValueError("This plotting function is only for 2D data")
# Get weights for both archetype sets
weights_first, weights_second = model.get_all_weights()
archetypes_first, archetypes_second = model.get_all_archetypes()
# Create reconstructions
X_recon_first = np.matmul(weights_first, archetypes_first)
# Handle different shapes of weights_second
if len(weights_second.shape) == 2 and weights_second.shape[0] == model.n_col_archetypes:
# If weights_second has shape (n_col_archetypes, n_features)
weights_second_transposed = weights_second.T
if weights_second_transposed.shape[0] == weights_first.shape[0]:
X_recon_second = np.matmul(weights_second_transposed, archetypes_second)
else:
# Create dummy data with matching shape if shapes don't match
X_recon_second = np.zeros_like(X_recon_first)
else:
# For standard shape
X_recon_second = np.matmul(weights_second, archetypes_second)
# Check and adjust shape of X_recon_second if necessary
if X_recon_second.shape != X_recon_first.shape:
# Create dummy data with matching shape if shapes don't match
X_recon_second_temp = X_recon_second.copy()
X_recon_second = np.zeros_like(X_recon_first)
# Use original data where possible
min_rows = min(X_recon_second_temp.shape[0], X_recon_first.shape[0])
min_cols = min(X_recon_second_temp.shape[1], X_recon_first.shape[1])
X_recon_second[:min_rows, :min_cols] = X_recon_second_temp[:min_rows, :min_cols]
# Create combined reconstruction using mixture weight
if hasattr(model, "mixture_weight"):
X_recon_combined = (
model.mixture_weight * X_recon_first + (1 - model.mixture_weight) * X_recon_second
)
else:
# Use equal weights if mixture_weight doesn't exist
X_recon_combined = 0.5 * X_recon_first + 0.5 * X_recon_second
# Calculate reconstruction errors
error_first = np.linalg.norm(X - X_recon_first, ord="fro")
error_second = np.linalg.norm(X - X_recon_second, ord="fro")
error_combined = np.linalg.norm(X - X_recon_combined, ord="fro")
# Create plot with subplots
_, axes = plt.subplots(2, 2, figsize=(15, 12))
# Original data
axes[0, 0].scatter(X[:, 0], X[:, 1], alpha=0.7, label="Original")
axes[0, 0].set_title("Original Data")
axes[0, 0].grid(True, alpha=0.3)
# First archetype set reconstruction
axes[0, 1].scatter(X_recon_first[:, 0], X_recon_first[:, 1], alpha=0.7, color="blue")
axes[0, 1].scatter(
archetypes_first[:, 0],
archetypes_first[:, 1],
c="blue",
s=100,
marker="*",
label="Archetypes Set 1",
)
axes[0, 1].set_title(f"First Set Reconstruction\nError: {error_first:.4f}")
axes[0, 1].grid(True, alpha=0.3)
# Second archetype set reconstruction
axes[1, 0].scatter(X_recon_second[:, 0], X_recon_second[:, 1], alpha=0.7, color="red")
axes[1, 0].scatter(
archetypes_second[:, 0],
archetypes_second[:, 1],
c="red",
s=100,
marker="^",
label="Archetypes Set 2",
)
axes[1, 0].set_title(f"Second Set Reconstruction\nError: {error_second:.4f}")
axes[1, 0].grid(True, alpha=0.3)
# Combined reconstruction
axes[1, 1].scatter(
X_recon_combined[:, 0], X_recon_combined[:, 1], alpha=0.7, color="purple"
)
axes[1, 1].scatter(
archetypes_first[:, 0],
archetypes_first[:, 1],
c="blue",
s=100,
marker="*",
label="Archetypes Set 1",
)
axes[1, 1].scatter(
archetypes_second[:, 0],
archetypes_second[:, 1],
c="red",
s=100,
marker="^",
label="Archetypes Set 2",
)
# Check if mixture_weight exists
mixture_weight = model.mixture_weight if hasattr(model, "mixture_weight") else 0.5
axes[1, 1].set_title(
f"Combined Reconstruction (w={mixture_weight:.2f})\nError: {error_combined:.4f}"
)
axes[1, 1].grid(True, alpha=0.3)
# Add legend to the last plot
axes[1, 1].legend()
plt.tight_layout()
plt.show()
[docs]
@staticmethod
def plot_dual_membership_heatmap(model: BiarchetypalAnalysis, n_samples: int = 50) -> None:
"""
Plot heatmap of membership weights for both sets of archetypes.
Args:
model: Fitted BiarchetypalAnalysis model
n_samples: Number of samples to visualize
"""
weights_first, weights_second = model.get_all_weights()
# Select a subset of samples
n_samples = min(n_samples, weights_first.shape[0])
# Sort samples by their dominant archetype in first set only
max_weight_first_idx = np.argmax(weights_first, axis=1)
sorted_indices = np.argsort(max_weight_first_idx)[:n_samples]
# Get the subsets for plotting
weights_first_subset = weights_first[sorted_indices]
# Create figure with two subplots
_, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 8))
# Plot first set heatmap
sns.heatmap(
weights_first_subset,
ax=ax1,
cmap="viridis",
annot=True,
fmt=".2f",
xticklabels=[f"A1_{i}" for i in range(model.n_row_archetypes)],
yticklabels=False,
)
ax1.set_title(f"First Set Membership Weights (n={model.n_row_archetypes})")
ax1.set_xlabel("Archetypes (First Set)")
ax1.set_ylabel("Samples")
# Plot second set heatmap
if len(weights_second.shape) == 2 and weights_second.shape[0] == model.n_col_archetypes:
# Use transposed weights_second
weights_second_transposed = weights_second.T
if weights_second_transposed.shape[0] == weights_first.shape[0]:
weights_second_subset = weights_second_transposed[sorted_indices]
else:
# Create matrix of ones if shapes don't match
weights_second_subset = np.ones((len(sorted_indices), model.n_col_archetypes))
else:
# Create matrix of ones for n_col_archetypes=1 case
weights_second_subset = np.ones((len(sorted_indices), 1))
sns.heatmap(
weights_second_subset,
ax=ax2,
cmap="viridis",
annot=True,
fmt=".2f",
xticklabels=[f"A2_{i}" for i in range(model.n_col_archetypes)],
yticklabels=False,
)
ax2.set_title(f"Second Set Membership Weights (n={model.n_col_archetypes})")
ax2.set_xlabel("Archetypes (Second Set)")
plt.tight_layout()
plt.show()
[docs]
@staticmethod
def plot_mixture_effect(
model: BiarchetypalAnalysis, X: np.ndarray, mixture_steps: int = 5
) -> None:
"""
Plot the effect of different mixture weights between the two archetype sets.
Args:
model: Fitted BiarchetypalAnalysis model
X: Original data matrix
mixture_steps: Number of different mixture weights to try
"""
if X.shape[1] != 2:
raise ValueError("This plotting function is only for 2D data")
# Get weights and archetypes for both sets
weights_first, weights_second = model.get_all_weights()
archetypes_first, archetypes_second = model.get_all_archetypes()
# Create reconstructions for first and second sets
X_recon_first = np.matmul(weights_first, archetypes_first)
# Handle different shapes of weights_second
if len(weights_second.shape) == 2 and weights_second.shape[0] == model.n_col_archetypes:
# If weights_second has shape (n_col_archetypes, n_features)
weights_second_transposed = weights_second.T
if weights_second_transposed.shape[0] == weights_first.shape[0]:
X_recon_second = np.matmul(weights_second_transposed, archetypes_second)
else:
# Create dummy data with matching shape if shapes don't match
X_recon_second = np.zeros_like(X_recon_first)
# Display warning
print(
"Warning: Shape mismatch in weights_second. Using dummy data for second reconstruction."
)
else:
# For standard shape
X_recon_second = np.matmul(weights_second, archetypes_second)
# Check and adjust shape of X_recon_second if necessary
if X_recon_second.shape != X_recon_first.shape:
# Create dummy data with matching shape if shapes don't match
X_recon_second_temp = X_recon_second.copy()
X_recon_second = np.zeros_like(X_recon_first)
# Use original data where possible
min_rows = min(X_recon_second_temp.shape[0], X_recon_first.shape[0])
min_cols = min(X_recon_second_temp.shape[1], X_recon_first.shape[1])
X_recon_second[:min_rows, :min_cols] = X_recon_second_temp[:min_rows, :min_cols]
# Display warning message
print(
f"Warning: Shape mismatch between reconstructions. X_recon_first: {X_recon_first.shape}, X_recon_second: {X_recon_second_temp.shape}"
)
# Create figure with subplots
n_rows = (mixture_steps + 2) // 3 # Ceiling division
_, axes = plt.subplots(n_rows, min(3, mixture_steps), figsize=(15, 4 * n_rows))
# Flatten axes if necessary
if mixture_steps > 3:
axes = axes.flatten()
elif mixture_steps == 1:
axes = [axes] # Convert to list for single subplot case
# Original data for reference
X_original = X.copy()
# Try different mixture weights
for i in range(mixture_steps):
# Calculate mixture weight
mix_weight = i / (mixture_steps - 1) if mixture_steps > 1 else 0.5
# Create mixed reconstruction
X_mixed = mix_weight * X_recon_first + (1 - mix_weight) * X_recon_second
# Calculate error
error = np.linalg.norm(X_original - X_mixed, ord="fro")
# Plot
ax = axes[i] if mixture_steps > 1 else axes
ax.scatter(
X_original[:, 0], X_original[:, 1], alpha=0.2, color="gray", label="Original"
)
ax.scatter(
X_mixed[:, 0], X_mixed[:, 1], alpha=0.7, color="purple", label="Reconstructed"
)
ax.scatter(
archetypes_first[:, 0],
archetypes_first[:, 1],
c="blue",
s=80,
marker="*",
label="Set 1",
)
ax.scatter(
archetypes_second[:, 0],
archetypes_second[:, 1],
c="red",
s=80,
marker="^",
label="Set 2",
)
ax.set_title(f"w={mix_weight:.2f}, Error={error:.4f}")
# Only add legend to the first plot
if i == 0:
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
[docs]
@staticmethod
def plot_dual_simplex_2d(model: BiarchetypalAnalysis, n_samples: int = 200) -> None:
"""
Plot samples in separate 2D simplex spaces for each archetype set (only works for 3 archetypes per set).
Args:
model: Fitted BiarchetypalAnalysis model
n_samples: Number of samples to plot
"""
# Relaxed condition: at least one of the archetype sets must have exactly 3 archetypes
if model.n_row_archetypes != 3 and model.n_col_archetypes != 3:
raise ValueError(
"This simplex plot requires at least one set to have exactly 3 archetypes"
)
# Get weights for both sets
weights_first, weights_second = model.get_all_weights()
# Select a subset of samples if needed
if n_samples < weights_first.shape[0]:
indices = np.random.choice(weights_first.shape[0], n_samples, replace=False)
weights_first_subset = weights_first[indices]
# Handle different shapes of weights_second
if len(weights_second.shape) == 2 and weights_second.shape[0] == model.n_col_archetypes:
# If weights_second has shape (n_col_archetypes, n_features)
weights_second_transposed = weights_second.T
if weights_second_transposed.shape[0] == weights_first.shape[0]:
weights_second_subset = weights_second_transposed[indices]
else:
# Create dummy data if shapes don't match
weights_second_subset = (
np.ones((len(indices), model.n_col_archetypes)) / model.n_col_archetypes
)
else:
# For standard shape
weights_second_subset = weights_second[indices]
else:
weights_first_subset = weights_first
# Handle different shapes of weights_second
if len(weights_second.shape) == 2 and weights_second.shape[0] == model.n_col_archetypes:
# If weights_second has shape (n_col_archetypes, n_features)
weights_second_transposed = weights_second.T
if weights_second_transposed.shape[0] == weights_first.shape[0]:
weights_second_subset = weights_second_transposed
else:
# Create dummy data if shapes don't match
weights_second_subset = (
np.ones((weights_first.shape[0], model.n_col_archetypes))
/ model.n_col_archetypes
)
else:
# For standard shape
weights_second_subset = weights_second
# Handle case where first archetype set doesn't have exactly 3 archetypes
if model.n_row_archetypes != 3:
# Create dummy data
weights_first_subset = np.ones((weights_first_subset.shape[0], 3)) / 3
# Handle case where second archetype set doesn't have exactly 3 archetypes
if model.n_col_archetypes != 3:
# Create dummy data
weights_second_subset = np.ones((weights_second_subset.shape[0], 3)) / 3
# Set up the figure with two subplots
_, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7))
# Convert barycentric coordinates to 2D for visualization
sqrt3_2 = np.sqrt(3) / 2
triangle_vertices = np.array(
[
[0, 0], # Archetype 0 at origin
[1, 0], # Archetype 1 at (1,0)
[0.5, sqrt3_2], # Archetype 2 at (0.5, sqrt(3)/2)
]
)
# Transform weights to 2D coordinates for both sets
points_2d_first = np.dot(weights_first_subset, triangle_vertices)
points_2d_second = np.dot(weights_second_subset, triangle_vertices)
# Create dominant archetype colormaps for each set
dominant_archetypes_first = np.argmax(weights_first_subset, axis=1)
dominant_archetypes_second = np.argmax(weights_second_subset, axis=1)
# Plot first set simplex
ax1.plot([0, 1, 0.5, 0], [0, 0, sqrt3_2, 0], "k-")
ax1.scatter(
points_2d_first[:, 0],
points_2d_first[:, 1],
c=dominant_archetypes_first,
alpha=0.6,
cmap="Blues",
)
ax1.text(-0.05, -0.05, "A1_0", ha="right", color="blue")
ax1.text(1.05, -0.05, "A1_1", ha="left", color="blue")
ax1.text(0.5, sqrt3_2 + 0.05, "A1_2", ha="center", color="blue")
ax1.set_title(
"First Archetype Set Simplex" + (" (Dummy)" if model.n_row_archetypes != 3 else "")
)
ax1.axis("equal")
ax1.axis("off")
# Plot second set simplex
ax2.plot([0, 1, 0.5, 0], [0, 0, sqrt3_2, 0], "k-")
ax2.scatter(
points_2d_second[:, 0],
points_2d_second[:, 1],
c=dominant_archetypes_second,
alpha=0.6,
cmap="Reds",
)
ax2.text(-0.05, -0.05, "A2_0", ha="right", color="red")
ax2.text(1.05, -0.05, "A2_1", ha="left", color="red")
ax2.text(0.5, sqrt3_2 + 0.05, "A2_2", ha="center", color="red")
ax2.set_title(
"Second Archetype Set Simplex" + (" (Dummy)" if model.n_col_archetypes != 3 else "")
)
ax2.axis("equal")
ax2.axis("off")
# Add grid lines for the simplex
for ax in [ax1, ax2]:
for i in range(1, 10):
p = i / 10
# Line parallel to the bottom edge
ax.plot(
[p * 0.5, p + (1 - p) * 0.5],
[p * sqrt3_2, (1 - p) * 0],
"gray",
alpha=0.3,
)
# Line parallel to the left edge
ax.plot([0, p * 0.5], [p * 0, p * sqrt3_2], "gray", alpha=0.3)
# Line parallel to the right edge
ax.plot(
[p * 1, 0.5 + (1 - p) * 0.5],
[p * 0, (1 - p) * sqrt3_2],
"gray",
alpha=0.3,
)
plt.tight_layout()
plt.show()