Examples
This section provides comprehensive examples demonstrating the application of archetypal analysis in various domains.
Synthetic Data Example
This example illustrates the fundamental concepts of archetypal analysis using synthetic data:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from archetypax import ArchetypalAnalysis
# Draw 300 points grouped around four blob centres (fixed seed for repeatability)
X, _ = make_blobs(
    n_samples=300,
    centers=4,
    cluster_std=0.60,
    random_state=42,
)

# Fit a four-archetype model to the points
model = ArchetypalAnalysis(
    n_archetypes=4,
    random_state=42,
)
model.fit(X)

# Fitted archetype coordinates, followed by the archetype weights of every sample
archetypes = model.archetypes_
weights = model.transform(X)
# Show the data, the archetypes, and the archetype hull in a single figure
from scipy.spatial import ConvexHull

plt.figure(figsize=(12, 10))

# Raw observations
plt.scatter(X[:, 0], X[:, 1], alpha=0.5, label='Data points')

# Archetypes drawn as large red stars with a black outline
plt.scatter(
    archetypes[:, 0],
    archetypes[:, 1],
    s=200,
    c='red',
    marker='*',
    label='Archetypes',
    edgecolors='black',
)

# Outline the convex hull spanned by the archetypes, one edge per simplex
hull = ConvexHull(archetypes)
for edge in hull.simplices:
    edge_x = archetypes[edge, 0]
    edge_y = archetypes[edge, 1]
    plt.plot(edge_x, edge_y, 'k-', lw=2)

plt.title('Archetypal Analysis on Synthetic Data', fontsize=16)
plt.legend(fontsize=12)
plt.grid(True, alpha=0.3)
plt.show()
Text Data Analysis
Archetypal analysis can be applied to text data to discover archetypal topics or document types:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import TruncatedSVD
from archetypax import ArchetypalAnalysis
# Restrict the 20 newsgroups corpus to four computing and four recreation groups
categories = [
    'comp.graphics',
    'comp.os.ms-windows.misc',
    'comp.sys.mac.hardware',
    'comp.sys.ibm.pc.hardware',
    'rec.autos',
    'rec.motorcycles',
    'rec.sport.baseball',
    'rec.sport.hockey',
]

# Training split only, with headers, footers and quoted replies removed
newsgroups = fetch_20newsgroups(
    subset='train',
    categories=categories,
    remove=('headers', 'footers', 'quotes'),
    random_state=42,
)
# TF-IDF features: at most 5000 terms, English stop words removed, and terms
# kept only if they occur in at least 5 documents and in at most 70% of them
vectorizer = TfidfVectorizer(
    stop_words='english',
    max_features=5000,
    min_df=5,
    max_df=0.7,
)
X_tfidf = vectorizer.fit_transform(newsgroups.data)
feature_names = vectorizer.get_feature_names_out()

# Project onto 100 SVD components to keep the archetypal fit cheap
svd = TruncatedSVD(
    n_components=100,
    random_state=42,
)
X_svd = svd.fit_transform(X_tfidf)
# Fit eight archetypes in the SVD-reduced space
model = ArchetypalAnalysis(
    n_archetypes=8,
    random_state=42,
)
model.fit(X_svd)

# Archetypes live in SVD space; map them back to term space so that each
# coordinate lines up with an entry of feature_names
archetypes_svd = model.archetypes_
archetypes_tfidf = svd.inverse_transform(archetypes_svd)

# Archetype weights of every document
weights = model.transform(X_svd)
def get_top_terms(archetype_vector, feature_names, top_n=15):
    """Return the ``top_n`` highest-weighted terms of one archetype.

    Args:
        archetype_vector: 1-D array-like of per-term weights.
        feature_names: Sequence of terms, indexable by position and aligned
            with ``archetype_vector``.
        top_n: Maximum number of terms to return. If it exceeds the number
            of terms, all terms are returned.

    Returns:
        A list of ``(term, weight)`` tuples ordered from the largest weight
        to the smallest.

    Raises:
        ValueError: If ``top_n`` is negative.
    """
    if top_n < 0:
        raise ValueError("top_n must be non-negative")

    vector = np.asarray(archetype_vector)
    # Reverse first, then truncate. Slicing with ``[-top_n:]`` before the
    # reversal would return *every* term for top_n == 0, because ``[-0:]``
    # is a full slice.
    top_indices = np.argsort(vector)[::-1][:top_n]
    return [(feature_names[i], vector[i]) for i in top_indices]
# Report the highest-weighted terms of every archetype, numbered from 1
print("Top terms for each archetype:")
for number, archetype in enumerate(archetypes_tfidf, start=1):
    print(f"\nArchetype {number}:")
    for term, weight in get_top_terms(archetype, feature_names):
        print(f" {term}: {weight:.4f}")
# Heatmap of the weight matrix: one row per document, one column per archetype
archetype_labels = [f'Archetype {k + 1}' for k in range(weights.shape[1])]

plt.figure(figsize=(14, 10))
sns.heatmap(
    weights,
    cmap='viridis',
    xticklabels=archetype_labels,
    yticklabels=False,  # one tick label per document would be unreadable
)
plt.title('Document Weights for Each Archetype', fontsize=16)
plt.xlabel('Archetypes', fontsize=14)
plt.ylabel('Documents', fontsize=14)
plt.tight_layout()
plt.show()
# Visualize the newsgroup mix of the documents dominated by each archetype

# Assign each document to the archetype with the largest weight
dominant_archetype = np.argmax(weights, axis=1)

# newsgroups.target holds indices into newsgroups.target_names, which
# scikit-learn keeps in sorted order. The hand-written `categories` list is
# in a different order, so indexing it with the targets would mislabel
# documents.
target_names = np.asarray(newsgroups.target_names)
df = pd.DataFrame({
    'category': target_names[newsgroups.target],
    'dominant_archetype': [f'Archetype {k + 1}' for k in dominant_archetype],
})

# Count documents by archetype (rows) and category (columns)
category_counts = (
    df.groupby(['dominant_archetype', 'category'])
    .size()
    .unstack(fill_value=0)
)

# Convert each archetype's row to percentages
category_percentages = category_counts.div(category_counts.sum(axis=1), axis=0) * 100

# Plot on an explicitly created axes: calling DataFrame.plot with `figsize`
# opens its own figure, which would leave a separately created one empty.
fig, ax = plt.subplots(figsize=(16, 10))
category_percentages.plot(
    kind='bar',
    stacked=True,
    colormap='tab10',
    ax=ax,
)
ax.set_title('Category Distribution for Each Archetype', fontsize=16)
ax.set_xlabel('Archetype', fontsize=14)
ax.set_ylabel('Percentage of Documents', fontsize=14)
ax.legend(title='Category', bbox_to_anchor=(1.05, 1), loc='upper left')
fig.tight_layout()
plt.show()
# Visualize documents and archetypes together in 2D using t-SNE
from sklearn.manifold import TSNE

# Embed documents and archetypes in ONE t-SNE run. t-SNE has no transform for
# new points, and a second fit_transform on the archetypes alone would (a)
# produce coordinates unrelated to the document embedding and (b) be invalid
# here, because perplexity must be smaller than the number of samples.
n_docs = X_svd.shape[0]
combined = np.vstack([np.asarray(X_svd), np.asarray(archetypes_svd)])
tsne = TSNE(n_components=2, random_state=42, perplexity=30)
combined_tsne = tsne.fit_transform(combined)
X_tsne = combined_tsne[:n_docs]
archetypes_tsne = combined_tsne[n_docs:]

# Create a scatter plot
plt.figure(figsize=(14, 10))

# Plot documents colored by category. newsgroups.target indexes
# newsgroups.target_names (sorted by scikit-learn), not the hand-written
# `categories` list, so the labels must come from target_names.
for label, category in enumerate(newsgroups.target_names):
    indices = np.where(newsgroups.target == label)[0]
    plt.scatter(
        X_tsne[indices, 0],
        X_tsne[indices, 1],
        alpha=0.5,
        label=category,
        s=30,
    )

# Plot archetypes as large black stars on top of the documents
plt.scatter(
    archetypes_tsne[:, 0],
    archetypes_tsne[:, 1],
    s=300,
    c='black',
    marker='*',
    label='Archetypes',
    edgecolors='white',
    linewidths=1.5,
)

# Label each archetype marker as A1, A2, ...
for i, (x, y) in enumerate(archetypes_tsne):
    plt.annotate(
        f'A{i+1}',
        (x, y),
        fontsize=12,
        fontweight='bold',
        color='white',
        ha='center',
        va='center',
    )

plt.title('t-SNE Projection of Documents and Archetypes', fontsize=16)
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
# Inspect the document that loads most heavily on one archetype
archetype_idx = 0  # Change this to analyze different archetypes

# argmax picks the single largest weight directly; no full sort is needed
top_doc_idx = int(np.argmax(weights[:, archetype_idx]))
doc_text = newsgroups.data[top_doc_idx]

# newsgroups.target indexes newsgroups.target_names (sorted by scikit-learn),
# not the hand-written `categories` list.
doc_category = newsgroups.target_names[newsgroups.target[top_doc_idx]]

print(f"\nExample document with high weight for Archetype {archetype_idx+1}:")
print(f"Category: {doc_category}")
print(f"Weights: {weights[top_doc_idx]}")
print("\nDocument text:")
# Show the first 500 characters; add an ellipsis only if text was cut off
max_chars = 500
print(doc_text[:max_chars] + ("..." if len(doc_text) > max_chars else ""))
Image Data Analysis
Archetypal analysis can be applied to image data to extract representative patterns:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_olivetti_faces
from archetypax import ArchetypalAnalysis
# Olivetti faces: each row of X is one flattened image
faces = fetch_olivetti_faces()
X = faces.data  # (400, 4096) - 400 images, 64x64 pixels flattened

# Fit ten archetypes directly in pixel space
model = ArchetypalAnalysis(
    n_archetypes=10,
    random_state=42,
)
model.fit(X)

# Each archetype is itself a flattened face image
archetypes = model.archetypes_
# Show the ten archetypes as a 2 x 5 grid of images without axis ticks
fig, axes = plt.subplots(
    2, 5,
    figsize=(15, 6),
    subplot_kw=dict(xticks=[], yticks=[]),
)
for idx, ax in enumerate(axes.ravel()):
    # Undo the flattening: 4096 pixels -> 64 x 64 image
    face_image = archetypes[idx].reshape(64, 64)
    ax.imshow(face_image, cmap='gray')
    ax.set_title(f'Archetype {idx + 1}')

plt.suptitle('Archetypal Faces', fontsize=16)
plt.tight_layout()
# Leave headroom so the suptitle does not overlap the top row
plt.subplots_adjust(top=0.9)
plt.show()
# Express one face as archetype weights, then rebuild it from those weights
sample_idx = 150
sample_face = X[sample_idx]
# The model is called with a 2-D (1, n_pixels) input, so add a leading axis
sample_weights = model.transform(sample_face[np.newaxis, :])
# NOTE(review): assumes the model exposes inverse_transform for mapping
# weights back to pixel space - confirm against the archetypax API.
reconstructed_face = model.inverse_transform(sample_weights)

# Original and reconstruction side by side
fig, (ax1, ax2) = plt.subplots(
    1, 2,
    figsize=(10, 5),
    subplot_kw=dict(xticks=[], yticks=[]),
)
panels = (
    (ax1, sample_face, 'Original Face'),
    (ax2, reconstructed_face, 'Reconstructed Face'),
)
for ax, image, caption in panels:
    ax.imshow(image.reshape(64, 64), cmap='gray')
    ax.set_title(caption)
plt.tight_layout()
plt.show()
Genomic Data Analysis
Archetypal analysis can identify archetypal expression patterns in genomic data:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from archetypax import ArchetypalAnalysis
# Simulate a gene expression matrix (rows: samples, columns: genes).
# In practice, you would load real data instead.
np.random.seed(42)  # fixed seed so the example is reproducible
n_samples = 200  # patients
n_genes = 1000  # genes
n_groups = 4  # number of planted sample groups

# Baseline expression: independent exponential noise with mean 1
X = np.random.exponential(scale=1.0, size=(n_samples, n_genes))

# Plant a block-diagonal structure: sample group g over-expresses (x3) gene
# block g. The block edges are derived from n_samples / n_genes / n_groups so
# the structure stays consistent if any of those sizes is changed.
sample_bounds = np.linspace(0, n_samples, n_groups + 1).astype(int)
gene_bounds = np.linspace(0, n_genes, n_groups + 1).astype(int)
for g in range(n_groups):
    rows = slice(sample_bounds[g], sample_bounds[g + 1])
    cols = slice(gene_bounds[g], gene_bounds[g + 1])
    X[rows, cols] *= 3
# Fit four archetypes to the expression matrix
model = ArchetypalAnalysis(
    n_archetypes=4,
    random_state=42,
)
model.fit(X)

# Archetypal expression profiles, followed by each sample's archetype weights
archetypes = model.archetypes_
weights = model.transform(X)
# Heatmap of the weight matrix: one row per sample, one column per archetype
archetype_labels = [f'Archetype {k + 1}' for k in range(weights.shape[1])]

plt.figure(figsize=(12, 8))
sns.heatmap(
    weights,
    cmap='viridis',
    xticklabels=archetype_labels,
    yticklabels=False,
)
plt.title('Sample Weights for Each Archetype', fontsize=16)
plt.xlabel('Archetypes', fontsize=14)
plt.ylabel('Samples', fontsize=14)
plt.tight_layout()
plt.show()
# One line plot per archetype (2 x 2 grid): expression level against gene index
plt.figure(figsize=(15, 10))
for panel in range(1, 5):
    ax = plt.subplot(2, 2, panel)
    ax.plot(archetypes[panel - 1])
    ax.set_title(f'Archetype {panel} Gene Expression Pattern', fontsize=14)
    ax.set_xlabel('Gene Index', fontsize=12)
    ax.set_ylabel('Expression Level', fontsize=12)
    ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
Market Segmentation Example
Archetypal analysis can be used for customer segmentation in marketing:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from archetypax import ArchetypalAnalysis
# Simulate customer data
np.random.seed(42)  # fixed seed so the example is reproducible
n_customers = 500

# Columns: age, income, spending, online_activity, store_visits
X = np.zeros((n_customers, 5))

# Four consecutive, (near-)equal customer segments. Deriving the edges from
# n_customers keeps the slices and sample counts consistent if it is changed.
segment_bounds = np.linspace(0, n_customers, 5).astype(int)
segments = list(zip(segment_bounds[:-1], segment_bounds[1:]))

# (mean, std) of age per segment: young, middle-aged, older middle-aged, senior
age_profiles = [(25, 5), (35, 5), (45, 5), (65, 5)]
# (mean, std) of income per segment: lower, middle, upper middle, retirement
income_profiles = [
    (40000, 10000),
    (70000, 15000),
    (100000, 20000),
    (60000, 15000),
]

for (lo, hi), (mean, std) in zip(segments, age_profiles):
    X[lo:hi, 0] = np.random.normal(mean, std, hi - lo)
for (lo, hi), (mean, std) in zip(segments, income_profiles):
    X[lo:hi, 1] = np.random.normal(mean, std, hi - lo)

# Remaining features are driven by age and income
age_factor = X[:, 0] / 40  # age relative to a reference age of 40
income_factor = X[:, 1] / 70000  # income relative to a reference of 70000

# One standard-normal draw per (customer, feature), in row-major order
noise = np.random.normal(size=(n_customers, 3))

# Spending (younger and higher income spend more)
X[:, 2] = 5000 * (2 - age_factor) * income_factor + 1000 * noise[:, 0]
# Online activity (younger are more active online)
X[:, 3] = 10 * (2 - age_factor) + 2 * noise[:, 1]
# Store visits (older visit stores more)
X[:, 4] = 20 * age_factor + 5 * noise[:, 2]
from sklearn.preprocessing import StandardScaler

# Column labels used by the plots, in the same order as the columns of X
feature_names = ['Age', 'Income', 'Spending', 'Online Activity', 'Store Visits']

# Standardize the features before fitting the model
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Fit four archetypes in the standardized space
model = ArchetypalAnalysis(
    n_archetypes=4,
    random_state=42,
)
model.fit(X_scaled)

# Archetypes in standardized units, and mapped back to the original units
archetypes = model.archetypes_
archetypes_original = scaler.inverse_transform(archetypes)

# Archetype weights of every customer
weights = model.transform(X_scaled)
# Radar-chart support for visualizing archetypes.
# No figure is opened here; one created at this point would stay empty,
# because the radar axes need their own figure with the 'radar' projection.
from matplotlib.path import Path
from matplotlib.spines import Spine
from matplotlib.transforms import Affine2D
from matplotlib.projections import register_projection


def radar_factory(num_vars, frame='circle'):
    """Register a ``'radar'`` polar projection with ``num_vars`` spokes.

    Calling this defines a ``PolarAxes`` subclass and registers it with
    Matplotlib under the projection name ``'radar'``, so it can then be
    requested with ``projection='radar'``.

    Args:
        num_vars: Number of variables, i.e. spokes of the radar chart.
        frame: Accepted for interface compatibility but not used; the axes
            patch is always a circle.

    Returns:
        Array of ``num_vars`` evenly spaced angles in radians, starting at 0
        and excluding 2*pi, to be used as the theta coordinate of each
        variable.
    """
    theta = np.linspace(0, 2 * np.pi, num_vars, endpoint=False)

    class RadarAxes(plt.PolarAxes):
        """Polar axes whose zero angle points up and whose lines are closed."""

        name = 'radar'

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Put the first variable at the top of the chart
            self.set_theta_zero_location('N')

        def fill(self, *args, closed=True, **kwargs):
            """Fill the polygon through the given points, closed by default.

            Delegating to ``fill_between`` instead would shade only the area
            swept from the first to the last vertex and leave the final
            wedge back to the first vertex empty, which does not match the
            closed outline drawn by ``plot``.
            """
            return super().fill(*args, closed=closed, **kwargs)

        def plot(self, *args, **kwargs):
            """Plot lines and close each one back to its starting point."""
            lines = super().plot(*args, **kwargs)
            self._close_polygon(lines)
            return lines

        def _close_polygon(self, lines):
            """Append the first point to any line that does not end on it."""
            for line in lines:
                x, y = line.get_data()
                if x[0] != x[-1]:
                    x = np.concatenate((x, [x[0]]))
                    y = np.concatenate((y, [y[0]]))
                    line.set_data(x, y)

        def set_varlabels(self, labels):
            """Label the spokes, one label per angle in ``theta``."""
            self.set_thetagrids(np.degrees(theta), labels)

        def _gen_axes_patch(self):
            # Circle centred in axes coordinates, filling the axes
            return plt.Circle((0.5, 0.5), 0.5)

    register_projection(RadarAxes)
    return theta
# Rescale every archetype coordinate by the observed range of that feature
# in X (min-max scaling), so all spokes of the radar chart share one scale
feature_min = X.min(axis=0)
feature_range = X.max(axis=0) - feature_min
archetypes_radar = (
    (archetypes_original - feature_min) / feature_range
).astype(archetypes_original.dtype)  # keep the dtype of archetypes_original
# Draw one radar chart per archetype on a 2 x 2 grid
theta = radar_factory(len(feature_names))
fig, axes = plt.subplots(
    nrows=2,
    ncols=2,
    figsize=(15, 12),
    subplot_kw={'projection': 'radar'},
)

colors = ['b', 'g', 'r', 'c']
n_features = len(feature_names)
panels = zip(axes.flat, colors, archetypes_radar, archetypes_original)
for ax, color, archetype, archetype_orig in panels:
    # Outline and shaded area use the range-normalized values
    ax.plot(theta, archetype, color=color)
    ax.fill(theta, archetype, facecolor=color, alpha=0.25)
    ax.set_varlabels(feature_names)

    # Print the value in original units just outside each spoke
    for k, value in enumerate(archetype_orig):
        angle = k * 2 * np.pi / n_features
        ax.text(
            angle,
            1.15,
            f"{value:.0f}",
            horizontalalignment='center',
            size='small',
        )

# Panel titles.
# NOTE(review): titles are assigned purely by position; nothing here ties the
# order of the fitted archetypes to these segment names - confirm against the
# archetype values before relying on the labels.
titles = [
    'Young Digital Shoppers',
    'Affluent Professionals',
    'Traditional Shoppers',
    'Senior Conservatives',
]
for ax, title in zip(axes.flat, titles):
    ax.set_title(
        title,
        weight='bold',
        size='medium',
        position=(0.5, 1.1),
        horizontalalignment='center',
        verticalalignment='center',
    )

plt.tight_layout()
plt.subplots_adjust(wspace=0.5, hspace=0.5)
plt.show()
# Label every customer with the archetype that has the largest weight
dominant_archetype = np.argmax(weights, axis=1)

# Age against income, colored by dominant archetype
plt.figure(figsize=(12, 10))
points = plt.scatter(
    X[:, 0],
    X[:, 1],
    c=dominant_archetype,
    cmap='viridis',
    alpha=0.7,
    s=50,
)
# One colorbar tick per archetype index (0-3)
plt.colorbar(points, ticks=range(4), label='Dominant Archetype')
plt.xlabel('Age', fontsize=14)
plt.ylabel('Income', fontsize=14)
plt.title('Customer Segmentation by Age and Income', fontsize=16)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()