# --- Part 1: Imports, Constants, Dataset Loading, Dirichlet Split ---

Created on: July 14, 2025

Answered using O4 Mini High by Chat01

Question

--- Part 1: Imports, Constants, Dataset Loading, Dirichlet Split ---

from sklearn.cluster import KMeans
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, Dataset
import numpy as np
from collections import defaultdict, Counter
from copy import deepcopy
import logging
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights
from medmnist import INFO, PathMNIST
import medmnist  # the dataset class is looked up via getattr(medmnist, ...)
from torch.cuda.amp import GradScaler, autocast
import ot # For Wasserstein barycenter

Add these imports at the top of your file with other imports

import json
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Rectangle

--- Logging Setup ---

logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s : %(message)s',
    handlers=[logging.FileHandler('14th-july-Enhanced-kmeans-perplexity.log'), logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

--- Utility Functions ---

def get_class_distribution(dataset, indices):
    """Get class distribution for given indices"""
    labels = [dataset[i][1] for i in indices]  # labels are plain ints here, so no .item()
    return Counter(labels)

--- Constants ---

BATCH_SIZE = 64
NUM_CLASSES = 9
NUM_CLIENTS = 10
LOCAL_EPOCHS = 5
COMMUNICATION_ROUNDS = 100
ALPHA = 0.5
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {DEVICE}")
logger.info("Using ImageNet pre-trained ResNet18 weights for feature extraction")

--- Enhanced Data Transforms for Augmentation ---

Strong augmentation for clients with smaller datasets

strong_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.3),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])

Standard augmentation

standard_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])

Test transform (no augmentation)

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])

info = INFO['pathmnist']
DataClass = getattr(medmnist, info['python_class'])  # resolves to PathMNIST

Load datasets without transforms initially

from PIL import Image

Load combined train + val images and labels

train_ds = DataClass(split='train', download=True)
val_ds = DataClass(split='val', download=True)

combined_imgs = np.concatenate([train_ds.imgs, val_ds.imgs], axis=0)
combined_labels = np.concatenate([train_ds.labels, val_ds.labels], axis=0).squeeze()

Define a simple dataset wrapper for use with existing code

class CombinedDataset(torch.utils.data.Dataset):
    def __init__(self, imgs, labels, transform=None):
        self.imgs = imgs
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img = Image.fromarray(self.imgs[idx])
        label = int(self.labels[idx])
        if self.transform:
            img = self.transform(img)
        return img, label

Now use this everywhere instead of full_train_dataset/full_val_dataset

full_combined_dataset = CombinedDataset(combined_imgs, combined_labels, transform=None)
full_train_dataset = full_combined_dataset
full_val_dataset = full_combined_dataset
full_test_dataset = DataClass(split='test', transform=transform_test, download=True)

--- Custom Dataset Class for Dynamic Augmentation ---

class AugmentedDataset(Dataset):
    def __init__(self, base_dataset, indices, transform, target_iterations, batch_size):
        self.base_dataset = base_dataset
        self.indices = indices
        self.transform = transform
        self.target_size = target_iterations * batch_size
        self.actual_size = len(indices)

        # Calculate how many times we need to repeat the dataset
        self.repeat_factor = max(1, self.target_size // self.actual_size)
        if self.target_size % self.actual_size > 0:
            self.repeat_factor += 1

        logger.info(f"Dataset with {self.actual_size} samples expanded to {self.target_size} samples "
                    f"(repeat factor: {self.repeat_factor})")

    def __len__(self):
        return self.target_size

    def __getitem__(self, idx):
        # Map the index to the original dataset
        original_idx = self.indices[idx % self.actual_size]
        image, label = self.base_dataset[original_idx]

        # Apply transform
        if self.transform:
            image = self.transform(image)

        return image, label

--- Dirichlet Data Split Function ---

def dirichlet_partition(dataset, alpha, num_clients, min_samples_per_class=20):
    labels = np.array([int(dataset[i][1]) for i in range(len(dataset))])
    num_classes = np.max(labels) + 1
    class_indices = [np.where(labels == y)[0] for y in range(num_classes)]
    client_indices = [[] for _ in range(num_clients)]

    for c, indices in enumerate(class_indices):
        np.random.shuffle(indices)

        # Check if we have enough samples for minimum allocation
        required_samples = min_samples_per_class * num_clients
        if len(indices) < required_samples:
            logger.warning(f"Class {c} has only {len(indices)} samples, but needs {required_samples}. "
                           f"Distributing available samples equally.")
            # Distribute available samples as equally as possible
            samples_per_client = len(indices) // num_clients
            remainder = len(indices) % num_clients
            start_idx = 0
            for client_id in range(num_clients):
                end_idx = start_idx + samples_per_client + (1 if client_id < remainder else 0)
                client_indices[client_id].extend(indices[start_idx:end_idx].tolist())
                start_idx = end_idx
            continue

        # Step 1: First distribute minimum samples to each client
        remaining_indices = list(indices)
        for client_id in range(num_clients):
            allocated_samples = remaining_indices[:min_samples_per_class]
            client_indices[client_id].extend(allocated_samples)
            remaining_indices = remaining_indices[min_samples_per_class:]

        # Step 2: Apply Dirichlet distribution to remaining samples
        if len(remaining_indices) > 0:
            proportions = np.random.dirichlet([alpha] * num_clients)
            # Convert proportions to actual sample counts
            sample_counts = (proportions * len(remaining_indices)).astype(int)

            # Handle rounding errors - distribute remaining samples
            remaining_to_distribute = len(remaining_indices) - sample_counts.sum()
            if remaining_to_distribute > 0:
                # Add remaining samples to clients with highest proportions
                top_clients = np.argsort(proportions)[-remaining_to_distribute:]
                for client_id in top_clients:
                    sample_counts[client_id] += 1

            # Distribute the remaining samples according to Dirichlet proportions
            start_idx = 0
            for client_id in range(num_clients):
                end_idx = start_idx + sample_counts[client_id]
                if end_idx > start_idx:  # Only add if there are samples to add
                    client_indices[client_id].extend(remaining_indices[start_idx:end_idx])
                start_idx = end_idx

    return client_indices

--- Global Test Set (Use entire test dataset) ---

global_test_dataset = full_test_dataset
global_test_loader = DataLoader(global_test_dataset, batch_size=BATCH_SIZE, shuffle=False)

--- Create Client Datasets using Dirichlet ---

#train_client_indices = dirichlet_partition(full_train_dataset, alpha=ALPHA, num_clients=NUM_CLIENTS, min_samples_per_class=20)
#val_client_indices = dirichlet_partition(full_val_dataset, alpha=ALPHA, num_clients=NUM_CLIENTS, min_samples_per_class=5)
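
For reference, here is a minimal sketch of how these commented-out partition calls could generate the JSON index files that the loader below expects. The directory and file names are taken from the loading code; the dict-of-lists layout with string client keys is an assumption that matches the loader's dict branch:

```python
# Hypothetical one-off preprocessing script (run once, then reuse the indices).
import json
import os

train_client_indices = dirichlet_partition(full_train_dataset, alpha=ALPHA,
                                           num_clients=NUM_CLIENTS, min_samples_per_class=20)
val_client_indices = dirichlet_partition(full_val_dataset, alpha=ALPHA,
                                         num_clients=NUM_CLIENTS, min_samples_per_class=5)

os.makedirs('0.5_dirichlet_pathmnist', exist_ok=True)
with open('0.5_dirichlet_pathmnist/train_client_indices.json', 'w') as f:
    # JSON object keys must be strings, matching the str(client_id) lookups below
    json.dump({str(cid): [int(i) for i in idxs]
               for cid, idxs in enumerate(train_client_indices)}, f)
with open('0.5_dirichlet_pathmnist/val_client_indices.json', 'w') as f:
    json.dump({str(cid): [int(i) for i in idxs]
               for cid, idxs in enumerate(val_client_indices)}, f)
```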

--- Load Client Datasets from JSON files ---

'''logger.info("Loading client indices from JSON files...")

Load train client indices

with open('train_client_indices.json', 'r') as f:
train_client_indices = json.load(f)

Load val client indices

with open('val_client_indices.json', 'r') as f:
val_client_indices = json.load(f)'''
logger.info("Loading client indices from JSON files...")

Load train client indices

with open('0.5_dirichlet_pathmnist/train_client_indices.json', 'r') as f:
    train_client_indices_raw = json.load(f)

Load val client indices

with open('0.5_dirichlet_pathmnist/val_client_indices.json', 'r') as f:
    val_client_indices_raw = json.load(f)

Convert to the expected format (list of lists)

train_client_indices = []
val_client_indices = []

If JSON contains a dictionary with client keys, convert to list

if isinstance(train_client_indices_raw, dict):
    for client_id in range(NUM_CLIENTS):
        client_key = str(client_id)  # JSON keys are strings
        if client_key in train_client_indices_raw:
            train_client_indices.append(train_client_indices_raw[client_key])
        else:
            logger.error(f"Client {client_id} not found in train_client_indices.json")
            raise KeyError(f"Client {client_id} not found in JSON")
else:
    # If JSON is already a list of lists
    train_client_indices = train_client_indices_raw

if isinstance(val_client_indices_raw, dict):
    for client_id in range(NUM_CLIENTS):
        client_key = str(client_id)
        if client_key in val_client_indices_raw:
            val_client_indices.append(val_client_indices_raw[client_key])
        else:
            logger.error(f"Client {client_id} not found in val_client_indices.json")
            raise KeyError(f"Client {client_id} not found in JSON")
else:
    val_client_indices = val_client_indices_raw

Verify the loaded data

logger.info(f"Loaded train indices for {len(train_client_indices)} clients")
logger.info(f"Loaded val indices for {len(val_client_indices)} clients")

Log sample sizes for verification

for client_id in range(NUM_CLIENTS):
    logger.info(f"Client {client_id + 1}: Train samples = {len(train_client_indices[client_id])}, "
                f"Val samples = {len(val_client_indices[client_id])}")
logger.info(f"Loaded train indices for {len(train_client_indices)} clients")
logger.info(f"Loaded val indices for {len(val_client_indices)} clients")

Calculate maximum dataset size for iteration equalization

max_train_size = max(len(indices) for indices in train_client_indices)
max_val_size = max(len(indices) for indices in val_client_indices)
max_train_iterations = max_train_size // BATCH_SIZE
max_val_iterations = max_val_size // BATCH_SIZE

logger.info(f"Maximum training dataset size: {max_train_size}")
logger.info(f"Maximum validation dataset size: {max_val_size}")
logger.info(f"Target iterations per epoch: {max_train_iterations} (train), {max_val_iterations} (val)")

Create client datasets with augmentation

client_datasets = []
for i in range(NUM_CLIENTS):
    train_size = len(train_client_indices[i])
    val_size = len(val_client_indices[i])

    # Determine if client needs strong augmentation (smaller datasets)
    if train_size < max_train_size * 0.7:  # If less than 70% of max size
        train_transform = strong_transform
        logger.info(f"Client {i+1}: Using strong augmentation (dataset size: {train_size})")
    else:
        train_transform = standard_transform
        logger.info(f"Client {i+1}: Using standard augmentation (dataset size: {train_size})")

    # Create augmented datasets
    train_dataset = AugmentedDataset(
        full_train_dataset, train_client_indices[i], train_transform,
        max_train_iterations, BATCH_SIZE
    )
    val_dataset = AugmentedDataset(
        full_val_dataset, val_client_indices[i], transform_test,
        max_val_iterations, BATCH_SIZE
    )

    client_datasets.append({
        'train': train_dataset,
        'val': val_dataset,
        'train_loader': DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True),
        'val_loader': DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)
    })

def calculate_client_data_percentages(train_indices, val_indices):
    """Calculate percentage of total data each client has"""
    total_train_samples = sum(len(indices) for indices in train_indices)
    total_val_samples = sum(len(indices) for indices in val_indices)

    train_percentages = []
    val_percentages = []

    for client_id in range(NUM_CLIENTS):
        train_pct = (len(train_indices[client_id]) / total_train_samples) * 100
        val_pct = (len(val_indices[client_id]) / total_val_samples) * 100
        train_percentages.append(train_pct)
        val_percentages.append(val_pct)

    return train_percentages, val_percentages

--- Log Class Distributions ---

Calculate percentages for use in training loop

#train_percentages, val_percentages = calculate_client_data_percentages(train_client_indices, val_client_indices)

--- Enhanced Data Distribution Logging and Visualization ---

Log individual client distributions (existing format for compatibility)

logger.info("\n" + "=" * 80)
logger.info("INDIVIDUAL CLIENT DISTRIBUTIONS")
logger.info("=" * 80)
for client_id in range(NUM_CLIENTS):
    train_dist = get_class_distribution(full_train_dataset, train_client_indices[client_id])
    val_dist = get_class_distribution(full_val_dataset, val_client_indices[client_id])
    logger.info(f"Client {client_id + 1} Train Dist: {dict(train_dist)} "
                f"(Original: {len(train_client_indices[client_id])}, Augmented: {len(client_datasets[client_id]['train'])})")
    logger.info(f"Client {client_id + 1} Val Dist: {dict(val_dist)} "
                f"(Original: {len(val_client_indices[client_id])}, Augmented: {len(client_datasets[client_id]['val'])})")

Calculate and log data percentages

train_percentages, val_percentages = calculate_client_data_percentages(train_client_indices, val_client_indices)
logger.info("\n" + "=" * 80)
logger.info("CLIENT DATA PERCENTAGES")
logger.info("=" * 80)
for client_id in range(NUM_CLIENTS):
    logger.info(f"Client {client_id + 1}: Train {train_percentages[client_id]:.1f}%, Val {val_percentages[client_id]:.1f}%")

Log global test distribution

global_test_labels = [full_test_dataset[i][1].item() for i in range(len(full_test_dataset))]
global_test_dist = Counter(global_test_labels)
logger.info(f"\nGlobal Test Distribution (Total: {len(full_test_dataset)}): {dict(global_test_dist)}")

Initialize checkpoint tracking

client_best_val_accs = [0.0] * NUM_CLIENTS # Track best validation accuracy for each client
client_best_checkpoints = [None] * NUM_CLIENTS # Track best checkpoint paths for each client

--- Part 2: VQ Model, Decoder, Aggregation, Parameter Utilities ---

--- GumbelSoftmax VQ Layer ---

class GumbelSoftmaxVQ(nn.Module):
    def __init__(self, num_embeddings=512, embedding_dim=512, temp_init=1.0):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.embeddings = nn.Parameter(torch.randn(num_embeddings, embedding_dim))
        self.proj = nn.Conv2d(embedding_dim, num_embeddings, 1)
        self.temperature = temp_init

    def forward(self, x):
        B, C, H, W = x.shape
        logits = self.proj(x).permute(0, 2, 3, 1).reshape(-1, self.num_embeddings)
        soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, hard=True)
        z_q = soft_one_hot @ self.embeddings
        z_q = z_q.view(B, H, W, C).permute(0, 3, 1, 2)

        avg_probs = soft_one_hot.mean(dim=0)
        entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-10))
        perplexity = torch.exp(entropy)

        commitment_loss = F.mse_loss(z_q.detach(), x)
        return z_q, commitment_loss, perplexity
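
A minimal smoke test of this layer, assuming the class as defined above (the 7×7 spatial size is just an example of an encoder output, not a claim about PathMNIST feature maps):

```python
import torch

vq = GumbelSoftmaxVQ(num_embeddings=512, embedding_dim=512)
features = torch.randn(4, 512, 7, 7)  # stand-in for an encoder output [B, C, H, W]

z_q, commitment_loss, perplexity = vq(features)
print(z_q.shape)    # torch.Size([4, 512, 7, 7]) -- quantized output keeps the input shape
print(perplexity)   # scalar in [1, 512]; higher means more codebook entries are in use
```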

--- VQ + Decoder ResNet Classifier ---

class VQDecoderResNetClassifier(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES, embedding_dim=512):
        super().__init__()
        # Use ImageNet pre-trained weights
        self.resnet = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        self.encoder = nn.Sequential(*list(self.resnet.children())[:-2])
        self.vq_layer = GumbelSoftmaxVQ(num_embeddings=512, embedding_dim=embedding_dim)

        self.decoder_conv1 = nn.Sequential(
            nn.Conv2d(embedding_dim, embedding_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(embedding_dim),
            nn.ReLU(inplace=True)
        )
        self.decoder_conv2 = nn.Sequential(
            nn.Conv2d(embedding_dim, embedding_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(embedding_dim),
            nn.ReLU(inplace=True)
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        features = self.encoder(x)
        quantized, vq_loss, perplexity = self.vq_layer(features)
        x1 = self.decoder_conv1(quantized)
        x2 = self.decoder_conv2(x1 + quantized)
        decoded = x2 + x1
        pooled = self.avgpool(decoded)
        pooled = torch.flatten(pooled, 1)
        output = self.classifier(pooled)
        return output, vq_loss, decoded, features, perplexity

--- Federated Averaging (FedAvg) ---

def federated_averaging(client_params_list, param_names):
    aggregated_params = {}
    for param_name in param_names:
        stacked_params = torch.stack([params[param_name] for params in client_params_list])
        aggregated_params[param_name] = torch.mean(stacked_params, dim=0)
    return aggregated_params

--- Wasserstein Barycenter Aggregation for VQ Embeddings ---

--- K-means Clustering Aggregation for VQ Embeddings ---

def kmeans_clustering_vq(client_vq_params_list, n_clusters=None, random_state=42):
    """
    Aggregate VQ embeddings using K-means clustering

    Args:
        client_vq_params_list: List of VQ parameter dictionaries from clients
        n_clusters: Number of clusters (if None, uses the codebook size)
        random_state: Random state for reproducibility

    Returns:
        aggregated_vq: Aggregated VQ parameters
    """
    aggregated_vq = {}

    for param_name in client_vq_params_list[0].keys():
        if 'embeddings' in param_name:
            # Extract embeddings from all clients
            embeddings_list = [params[param_name].cpu().numpy() for params in client_vq_params_list]

            # Get dimensions
            n_clients = len(embeddings_list)
            n_embeddings, embedding_dim = embeddings_list[0].shape

            # Set number of clusters if not specified
            if n_clusters is None:
                n_clusters = n_embeddings

            logger.info(f"Aggregating {param_name} using K-means with {n_clusters} clusters")

            # Concatenate all embeddings from all clients
            all_embeddings = np.concatenate(embeddings_list, axis=0)  # Shape: (n_clients * n_embeddings, embedding_dim)

            # Apply K-means clustering
            kmeans = KMeans(n_clusters=n_clusters, random_state=random_state, n_init=10)
            kmeans.fit(all_embeddings)

            # Use cluster centers as the aggregated embeddings
            aggregated_embeddings = kmeans.cluster_centers_

            # Ensure we have the correct number of embeddings
            if aggregated_embeddings.shape[0] != n_embeddings:
                logger.warning(f"K-means produced {aggregated_embeddings.shape[0]} clusters, "
                               f"but need {n_embeddings} embeddings")
                if aggregated_embeddings.shape[0] > n_embeddings:
                    # Take the first n_embeddings clusters
                    aggregated_embeddings = aggregated_embeddings[:n_embeddings]
                else:
                    # Pad with random embeddings if we have fewer clusters
                    remaining = n_embeddings - aggregated_embeddings.shape[0]
                    padding = np.random.randn(remaining, embedding_dim) * 0.1
                    aggregated_embeddings = np.concatenate([aggregated_embeddings, padding], axis=0)

            aggregated_vq[param_name] = torch.tensor(aggregated_embeddings, dtype=torch.float32)
            logger.info(f"Aggregated {param_name}: {aggregated_embeddings.shape}")
        else:
            # For non-embedding parameters, use standard averaging
            stacked = torch.stack([params[param_name] for params in client_vq_params_list])
            aggregated_vq[param_name] = torch.mean(stacked, dim=0)

    return aggregated_vq

--- Alternative: Weighted K-means based on client data sizes ---

def weighted_kmeans_clustering_vq(client_vq_params_list, client_data_sizes, n_clusters=None, random_state=42):
    """
    Aggregate VQ embeddings using K-means clustering weighted by client data sizes

    Args:
        client_vq_params_list: List of VQ parameter dictionaries from clients
        client_data_sizes: List of data sizes for each client (for weighting)
        n_clusters: Number of clusters (if None, uses the codebook size)
        random_state: Random state for reproducibility

    Returns:
        aggregated_vq: Aggregated VQ parameters
    """
    aggregated_vq = {}

    for param_name in client_vq_params_list[0].keys():
        if 'embeddings' in param_name:
            # Extract embeddings from all clients
            embeddings_list = [params[param_name].cpu().numpy() for params in client_vq_params_list]

            # Get dimensions
            n_clients = len(embeddings_list)
            n_embeddings, embedding_dim = embeddings_list[0].shape

            # Set number of clusters if not specified
            if n_clusters is None:
                n_clusters = n_embeddings

            logger.info(f"Aggregating {param_name} using weighted K-means with {n_clusters} clusters")

            # Create weighted samples based on client data sizes
            total_data_size = sum(client_data_sizes)
            client_weights = [size / total_data_size for size in client_data_sizes]

            # Create weighted embeddings
            weighted_embeddings = []
            for i, embeddings in enumerate(embeddings_list):
                # Replicate embeddings based on client weight
                weight = client_weights[i]
                n_replicas = max(1, int(weight * n_clients * n_embeddings))
                # Sample embeddings with replacement
                indices = np.random.choice(n_embeddings, size=n_replicas, replace=True)
                weighted_embeddings.append(embeddings[indices])

            # Concatenate all weighted embeddings
            all_weighted_embeddings = np.concatenate(weighted_embeddings, axis=0)

            # Apply K-means clustering
            kmeans = KMeans(n_clusters=n_clusters, random_state=random_state, n_init=10)
            kmeans.fit(all_weighted_embeddings)

            # Use cluster centers as the aggregated embeddings
            aggregated_embeddings = kmeans.cluster_centers_

            # Ensure we have the correct number of embeddings
            if aggregated_embeddings.shape[0] != n_embeddings:
                logger.warning(f"Weighted K-means produced {aggregated_embeddings.shape[0]} clusters, "
                               f"but need {n_embeddings} embeddings")
                if aggregated_embeddings.shape[0] > n_embeddings:
                    # Take the first n_embeddings clusters
                    aggregated_embeddings = aggregated_embeddings[:n_embeddings]
                else:
                    # Pad with random embeddings if we have fewer clusters
                    remaining = n_embeddings - aggregated_embeddings.shape[0]
                    padding = np.random.randn(remaining, embedding_dim) * 0.1
                    aggregated_embeddings = np.concatenate([aggregated_embeddings, padding], axis=0)

            aggregated_vq[param_name] = torch.tensor(aggregated_embeddings, dtype=torch.float32)
            logger.info(f"Aggregated {param_name}: {aggregated_embeddings.shape}")
        else:
            # For non-embedding parameters, use standard averaging
            stacked = torch.stack([params[param_name] for params in client_vq_params_list])
            aggregated_vq[param_name] = torch.mean(stacked, dim=0)

    return aggregated_vq

--- Enhanced K-means with Multiple Initializations ---

def enhanced_kmeans_clustering_vq(client_vq_params_list, client_usage_data,
                                  n_clusters=None, random_state=42, n_init=20):
    """
    Enhanced K-means clustering with codebook usage tracking

    Args:
        client_vq_params_list: List of VQ parameter dictionaries from clients
        client_usage_data: List of (used_indices, usage_counts) tuples for each client
        n_clusters: Number of clusters (if None, uses the codebook size)
        random_state: Random state for reproducibility
        n_init: Number of random initializations

    Returns:
        aggregated_vq: Aggregated VQ parameters
        aggregation_stats: Statistics about the aggregation process
    """
    aggregated_vq = {}
    aggregation_stats = {}

    for param_name in client_vq_params_list[0].keys():
        if 'embeddings' in param_name:
            # Extract embeddings from all clients
            embeddings_list = [params[param_name].cpu().numpy() for params in client_vq_params_list]

            # Get dimensions
            n_clients = len(embeddings_list)
            n_embeddings, embedding_dim = embeddings_list[0].shape

            # Set number of clusters if not specified
            if n_clusters is None:
                n_clusters = n_embeddings

            logger.info(f"Aggregating {param_name} using usage-aware enhanced K-means")

            # Collect only used embeddings from all clients
            used_embeddings = []
            used_indices_global = set()
            client_contributions = []

            for client_id, (embeddings, (used_indices, usage_counts)) in enumerate(
                    zip(embeddings_list, client_usage_data)):
                used_indices_global.update(used_indices)

                # Extract only used embeddings from this client
                if used_indices:
                    client_used_embeddings = embeddings[list(used_indices)]
                    used_embeddings.append(client_used_embeddings)
                    client_contributions.append(len(used_indices))
                    logger.info(f"Client {client_id+1}: Using {len(used_indices)} out of {n_embeddings} embeddings")
                else:
                    logger.warning(f"Client {client_id+1}: No used embeddings found, using all embeddings")
                    used_embeddings.append(embeddings)
                    client_contributions.append(n_embeddings)

            # Combine all used embeddings
            if used_embeddings:
                all_used_embeddings = np.concatenate(used_embeddings, axis=0)
                logger.info(f"Total used embeddings for clustering: {all_used_embeddings.shape[0]}")
            else:
                # Fallback: use all embeddings if no usage data
                all_used_embeddings = np.concatenate(embeddings_list, axis=0)
                logger.warning("No usage data available, using all embeddings")

            # Apply K-means clustering with enhanced parameters
            actual_clusters = min(n_clusters, all_used_embeddings.shape[0])
            kmeans = KMeans(
                n_clusters=actual_clusters,
                random_state=random_state,
                n_init=n_init,
                init='k-means++',
                max_iter=300,
                tol=1e-4
            )

            try:
                kmeans.fit(all_used_embeddings)
                cluster_centers = kmeans.cluster_centers_
                inertia = kmeans.inertia_
                logger.info(f"K-means completed: {actual_clusters} clusters, inertia: {inertia:.4f}")
            except Exception as e:
                logger.error(f"K-means clustering failed: {e}")
                # Fallback to simple averaging
                cluster_centers = np.mean(embeddings_list, axis=0)
                inertia = 0.0

            # Create final embedding matrix
            if cluster_centers.shape[0] != n_embeddings:
                if cluster_centers.shape[0] > n_embeddings:
                    # Take the first n_embeddings clusters
                    final_embeddings = cluster_centers[:n_embeddings]
                else:
                    # Pad with unused embeddings from clients or random initialization
                    final_embeddings = np.zeros((n_embeddings, embedding_dim))
                    final_embeddings[:cluster_centers.shape[0]] = cluster_centers

                    # Fill remaining positions with averaged unused embeddings
                    remaining = n_embeddings - cluster_centers.shape[0]
                    if remaining > 0:
                        # Get unused embeddings from clients
                        unused_embeddings = []
                        for client_id, (embeddings, (used_indices, _)) in enumerate(
                                zip(embeddings_list, client_usage_data)):
                            all_indices = set(range(n_embeddings))
                            unused_indices = all_indices - used_indices
                            if unused_indices:
                                unused_embeddings.append(embeddings[list(unused_indices)])

                        if unused_embeddings:
                            all_unused = np.concatenate(unused_embeddings, axis=0)
                            if all_unused.shape[0] >= remaining:
                                # Use K-means on unused embeddings
                                unused_kmeans = KMeans(n_clusters=remaining, random_state=random_state, n_init=5)
                                unused_centers = unused_kmeans.fit(all_unused).cluster_centers_
                                final_embeddings[cluster_centers.shape[0]:] = unused_centers
                            else:
                                # Use all unused embeddings and pad with noise
                                final_embeddings[cluster_centers.shape[0]:cluster_centers.shape[0] + all_unused.shape[0]] = all_unused
                                # Fill remaining with small random noise around used embeddings
                                if cluster_centers.shape[0] > 0:
                                    noise_base = np.mean(cluster_centers, axis=0)
                                    noise = np.random.randn(remaining - all_unused.shape[0], embedding_dim) * 0.01
                                    final_embeddings[cluster_centers.shape[0] + all_unused.shape[0]:] = noise_base + noise
                        else:
                            # No unused embeddings available, use random initialization
                            final_embeddings[cluster_centers.shape[0]:] = np.random.randn(remaining, embedding_dim) * 0.1
            else:
                final_embeddings = cluster_centers

            aggregated_vq[param_name] = torch.tensor(final_embeddings, dtype=torch.float32)

            # Store aggregation statistics
            aggregation_stats[param_name] = {
                'used_embeddings_count': len(used_indices_global),
                'total_embeddings': n_embeddings,
                'usage_ratio': len(used_indices_global) / n_embeddings,
                'client_contributions': client_contributions,
                'clustering_inertia': inertia,
                'final_shape': final_embeddings.shape
            }

            logger.info(f"Aggregated {param_name}: {final_embeddings.shape}, "
                        f"Used {len(used_indices_global)}/{n_embeddings} embeddings "
                        f"({100*len(used_indices_global)/n_embeddings:.1f}%)")
        else:
            # For non-embedding parameters, use standard averaging
            stacked = torch.stack([params[param_name] for params in client_vq_params_list])
            aggregated_vq[param_name] = torch.mean(stacked, dim=0)

    return aggregated_vq, aggregation_stats

--- Parameter Selection Utilities ---

def get_shared_parameters(model):
    vq_params, decoder_params, classifier_params = {}, {}, {}
    for name, param in model.named_parameters():
        if 'vq_layer' in name:
            vq_params[name] = param.data.clone()
        elif 'decoder' in name:
            decoder_params[name] = param.data.clone()
        elif 'classifier' in name or 'avgpool' in name:
            classifier_params[name] = param.data.clone()
    return vq_params, decoder_params, classifier_params

def track_codebook_usage(model, data_loader, threshold=0.01):
    """
    Track which codebook entries are being used by a model

    Args:
        model: The VQ model
        data_loader: DataLoader to run inference on
        threshold: Minimum usage frequency to consider an embedding as "used"

    Returns:
        used_indices: Set of indices of used codebook entries
        usage_counts: Dictionary mapping codebook index to usage count
    """
    model.eval()
    codebook_size = model.vq_layer.num_embeddings
    usage_counts = np.zeros(codebook_size)
    total_tokens = 0

    with torch.no_grad():
        for images, _ in data_loader:
            images = images.to(DEVICE)

            # Get VQ layer output
            features = model.encoder(images)
            B, C, H, W = features.shape

            # Get logits and softmax probabilities
            logits = model.vq_layer.proj(features).permute(0, 2, 3, 1).reshape(-1, codebook_size)
            probs = F.softmax(logits, dim=-1)

            # Count usage based on max probability (which embedding is most likely chosen)
            max_indices = torch.argmax(probs, dim=-1)

            # Count occurrences
            for idx in max_indices.cpu().numpy():
                usage_counts[idx] += 1
            total_tokens += len(max_indices)

    # Calculate usage frequencies
    usage_frequencies = usage_counts / total_tokens

    # Find used indices based on threshold
    used_indices = set(np.where(usage_frequencies >= threshold)[0])

    logger.info(f"Codebook usage: {len(used_indices)}/{codebook_size} embeddings used "
                f"(threshold: {threshold}, max_freq: {usage_frequencies.max():.4f})")

    return used_indices, dict(enumerate(usage_counts))

def update_shared_parameters(model, vq_params, decoder_params, classifier_params):
    state_dict = model.state_dict()
    for name, param in vq_params.items():
        if name in state_dict:
            state_dict[name].copy_(param.to(DEVICE))
    for name, param in decoder_params.items():
        if name in state_dict:
            state_dict[name].copy_(param.to(DEVICE))
    for name, param in classifier_params.items():
        if name in state_dict:
            state_dict[name].copy_(param.to(DEVICE))

--- Smart Checkpoint Management ---

def save_best_checkpoint(client_id, round_num, model_state, val_acc):
    """Save checkpoint only if it's better than previous best"""
    global client_best_val_accs, client_best_checkpoints

    os.makedirs("checkpoints", exist_ok=True)

    if val_acc > client_best_val_accs[client_id]:
        # Delete old checkpoint if it exists
        if client_best_checkpoints[client_id] is not None:
            old_checkpoint = client_best_checkpoints[client_id]
            if os.path.exists(old_checkpoint):
                os.remove(old_checkpoint)
                logger.info(f"Removed old checkpoint: {old_checkpoint}")

        # Save new best checkpoint
        new_checkpoint = f"14th-july-Enhanced-kmeans-perplexity_client_{client_id+1}_best.pth"
        torch.save(model_state, new_checkpoint)

        # Update tracking
        client_best_val_accs[client_id] = val_acc
        client_best_checkpoints[client_id] = new_checkpoint

        logger.info(f"Client {client_id+1}: New best checkpoint saved with val_acc: {val_acc:.2f}%")
        return True
    else:
        logger.info(f"Client {client_id+1}: Val acc {val_acc:.2f}% not better than best "
                    f"{client_best_val_accs[client_id]:.2f}%")
        return False

--- Enhanced Table Display Functions ---

def print_global_test_table(round_num, global_test_accs, train_percentages):
    """Print global test accuracies in tabular format with client percentages"""
    logger.info(f"\n{'='*90}")
    logger.info(f"GLOBAL TEST ACCURACIES - ROUND {round_num + 1}")
    logger.info(f"{'='*90}")

    # Header with percentages
    header = "| Client |"
    for i in range(NUM_CLIENTS):
        header += f" {i+1:2d}({train_percentages[i]:4.1f}%) |"
    header += " Average |"
    logger.info(header)
    logger.info("|" + "-" * 8 + "|" + "-" * 11 * NUM_CLIENTS + "|" + "-" * 9 + "|")

    # Accuracy row
    acc_row = "| Acc(%) |"
    for acc in global_test_accs:
        acc_row += f" {acc:5.2f} |"
    avg_acc = np.mean(global_test_accs)
    acc_row += f" {avg_acc:6.2f} |"
    logger.info(acc_row)
    logger.info(f"{'='*90}")

    return avg_acc

def print_consolidated_table(round_num, all_round_data):
    """Print consolidated table every 5 rounds"""
    logger.info(f"\n{'='*100}")
    logger.info(f"CONSOLIDATED GLOBAL TEST ACCURACIES - ROUNDS 1 to {round_num + 1}")
    logger.info(f"{'='*100}")

    # Header
    header = "| Round |" + " ".join([f" {i+1:2d} |" for i in range(NUM_CLIENTS)]) + " Average |"
    logger.info(header)
    logger.info("|" + "-" * 7 + "|" + "-" * 6 * NUM_CLIENTS + "|" + "-" * 9 + "|")

    # Show last 5 rounds or all if less than 5
    start_round = max(0, round_num - 4)
    for r in range(start_round, round_num + 1):
        if r < len(all_round_data):
            round_data = all_round_data[r]
            acc_row = f"| {r+1:5d} |"
            for acc in round_data['accs']:
                acc_row += f"{acc:5.2f} |"
            acc_row += f" {round_data['avg']:6.2f} |"
            logger.info(acc_row)

    logger.info(f"{'='*100}")

    # Show best performance in this window
    if all_round_data:
        window_data = all_round_data[start_round:round_num+1]
        best_avg = max(data['avg'] for data in window_data)
        best_round = next(r for r, data in enumerate(window_data, start_round) if data['avg'] == best_avg)
        logger.info(f"Best Average in Window: {best_avg:.2f}% (Round {best_round + 1})")

    logger.info(f"{'='*100}")

def evaluate_class_wise_accuracy(model, test_loader, num_classes=NUM_CLASSES):
    """Evaluate class-wise accuracy for a model"""
    model.eval()
    class_correct = [0] * num_classes
    class_total = [0] * num_classes

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(DEVICE), labels.squeeze().long().to(DEVICE)
            with autocast():
                outputs, _, _, _, _ = model(images)
            _, preds = torch.max(outputs, 1)

            # Calculate class-wise accuracy
            for i in range(labels.size(0)):
                label = labels[i].item()
                class_total[label] += 1
                if preds[i] == label:
                    class_correct[label] += 1

    # Calculate accuracy for each class
    class_accuracies = []
    for i in range(num_classes):
        if class_total[i] > 0:
            accuracy = 100 * class_correct[i] / class_total[i]
            class_accuracies.append(accuracy)
        else:
            class_accuracies.append(0.0)  # No samples for this class

    return class_accuracies, class_total

def print_class_wise_accuracy_table(round_num, client_models, global_test_loader, train_percentages):
    """Print class-wise accuracy table for all clients"""
    logger.info(f"\n{'='*120}")
    logger.info(f"CLASS-WISE GLOBAL TEST ACCURACIES - ROUND {round_num + 1}")
    logger.info(f"{'='*120}")

    # Evaluate class-wise accuracy for each client
    all_client_class_accs = []
    all_client_class_totals = []
    for client_id in range(NUM_CLIENTS):
        class_accs, class_totals = evaluate_class_wise_accuracy(client_models[client_id], global_test_loader)
        all_client_class_accs.append(class_accs)
        all_client_class_totals.append(class_totals)

    # Create header with client percentages
    header = "| Class |"
    for i in range(NUM_CLIENTS):
        header += f" {i+1:2d}({train_percentages[i]:4.1f}%) |"
    header += " Average |"
    logger.info(header)
    logger.info("|" + "-" * 7 + "|" + "-" * 11 * NUM_CLIENTS + "|" + "-" * 9 + "|")

    # Print accuracy for each class
    for class_id in range(NUM_CLASSES):
        class_row = f"| {class_id:5d} |"
        class_accs_for_avg = []
        for client_id in range(NUM_CLIENTS):
            acc = all_client_class_accs[client_id][class_id]
            class_row += f" {acc:5.2f} |"
            class_accs_for_avg.append(acc)
        avg_acc = np.mean(class_accs_for_avg)
        class_row += f" {avg_acc:6.2f} |"
        logger.info(class_row)

    logger.info(f"{'='*120}")

    # Calculate and print overall average
    overall_client_avgs = []
    for client_id in range(NUM_CLIENTS):
        client_avg = np.mean(all_client_class_accs[client_id])
        overall_client_avgs.append(client_avg)
    overall_avg = np.mean(overall_client_avgs)

    # Print overall average row
    logger.info("OVERALL AVERAGE:")
    avg_row = "| Avg |"
    for client_avg in overall_client_avgs:
        avg_row += f" {client_avg:5.2f} |"
    avg_row += f" {overall_avg:6.2f} |"
    logger.info(avg_row)
    logger.info(f"{'='*120}")

    return overall_client_avgs, overall_avg

--- Visualization Functions ---

def plot_client_data_distribution(train_indices, val_indices, full_train_dataset, full_val_dataset):
    """Plot histogram showing class distribution for each client"""

    # Get class distributions for all clients
    train_distributions = []
    val_distributions = []

    for client_id in range(NUM_CLIENTS):
        train_dist = get_class_distribution(full_train_dataset, train_indices[client_id])
        val_dist = get_class_distribution(full_val_dataset, val_indices[client_id])

        # Convert to arrays for plotting
        train_counts = [train_dist.get(i, 0) for i in range(NUM_CLASSES)]
        val_counts = [val_dist.get(i, 0) for i in range(NUM_CLASSES)]

        train_distributions.append(train_counts)
        val_distributions.append(val_counts)

    # Create subplots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

    # Plot training data distribution
    x = np.arange(NUM_CLIENTS)
    width = 0.8
    bottom_train = np.zeros(NUM_CLIENTS)
    colors = plt.cm.tab10(np.linspace(0, 1, NUM_CLASSES))

    for class_id in range(NUM_CLASSES):
        class_counts = [train_distributions[client][class_id] for client in range(NUM_CLIENTS)]
        ax1.bar(x, class_counts, width, bottom=bottom_train, label=f'Class {class_id}', color=colors[class_id])
        bottom_train += class_counts

    ax1.set_xlabel('Client ID')
    ax1.set_ylabel('Number of Samples')
    ax1.set_title('Training Data Distribution per Client')
    ax1.set_xticks(x)
    ax1.set_xticklabels([f'Client {i+1}' for i in range(NUM_CLIENTS)])
    ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    # Plot validation data distribution
    bottom_val = np.zeros(NUM_CLIENTS)
    for class_id in range(NUM_CLASSES):
        class_counts = [val_distributions[client][class_id] for client in range(NUM_CLIENTS)]
        ax2.bar(x, class_counts, width, bottom=bottom_val, label=f'Class {class_id}', color=colors[class_id])
        bottom_val += class_counts

    ax2.set_xlabel('Client ID')
    ax2.set_ylabel('Number of Samples')
    ax2.set_title('Validation Data Distribution per Client')
    ax2.set_xticks(x)
    ax2.set_xticklabels([f'Client {i+1}' for i in range(NUM_CLIENTS)])
    ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.tight_layout()
    plt.savefig('client_data_distribution.png', dpi=300, bbox_inches='tight')
    plt.show()

    logger.info("Data distribution histogram saved as 'client_data_distribution.png'")

def print_enhanced_client_distribution_table(train_indices, val_indices, full_train_dataset, full_val_dataset):
    """Print enhanced distribution table with percentages"""

    # Calculate percentages
    train_percentages, val_percentages = calculate_client_data_percentages(train_indices, val_indices)

    logger.info("=" * 120)
    logger.info("ENHANCED CLIENT-WISE DATA DISTRIBUTION WITH PERCENTAGES")
    logger.info("=" * 120)

    # Create header with percentages
    header = "| Client |"
    for i in range(NUM_CLIENTS):
        header += f" {i+1:2d}({train_percentages[i]:4.1f}%) |"
    header += " Total |"
    logger.info(header)
    logger.info("|" + "-" * 8 + "|" + "-" * 11 * NUM_CLIENTS + "|" + "-" * 7 + "|")

    # Print training data distribution
    logger.info("TRAINING DATA:")
    for class_id in range(NUM_CLASSES):
        class_row = f"| Cls {class_id:2d} |"
        class_total = 0
        for client_id in range(NUM_CLIENTS):
            train_dist = get_class_distribution(full_train_dataset, train_indices[client_id])
            count = train_dist.get(class_id, 0)
            class_row += f" {count:4d} |"
            class_total += count
        class_row += f" {class_total:4d} |"
        logger.info(class_row)

    # Print total training samples per client
    total_row = "| Total |"
    grand_total = 0
    for client_id in range(NUM_CLIENTS):
        client_total = len(train_indices[client_id])
        total_row += f" {client_total:4d} |"
        grand_total += client_total
    total_row += f" {grand_total:4d} |"
    logger.info(total_row)
    logger.info("|" + "-" * 8 + "|" + "-" * 11 * NUM_CLIENTS + "|" + "-" * 7 + "|")

    # Print validation data distribution
    logger.info("VALIDATION DATA:")
    for class_id in range(NUM_CLASSES):
        class_row = f"| Cls {class_id:2d} |"
        class_total = 0
        for client_id in range(NUM_CLIENTS):
            val_dist = get_class_distribution(full_val_dataset, val_indices[client_id])
            count = val_dist.get(class_id, 0)
            class_row += f" {count:4d} |"
            class_total += count
        class_row += f" {class_total:4d} |"
        logger.info(class_row)

    # Print total validation samples per client
    total_row = "| Total |"
    grand_total = 0
    for client_id in range(NUM_CLIENTS):
        client_total = len(val_indices[client_id])
        total_row += f" {client_total:4d} |"
        grand_total += client_total
    total_row += f" {grand_total:4d} |"
    logger.info(total_row)
    logger.info("=" * 120)

Print enhanced distribution table with percentages

print_enhanced_client_distribution_table(train_client_indices, val_client_indices, full_train_dataset, full_val_dataset)

Create and save histogram visualization

plot_client_data_distribution(train_client_indices, val_client_indices, full_train_dataset, full_val_dataset)

--- Training and Evaluation Functions ---

def train_client(model, train_loader, val_loader, client_id, round_num):
    """Train a single client using AMP with iteration-based training and usage tracking"""
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scaler = GradScaler()

    best_val_acc = 0
    best_model_state = None
    model.to(DEVICE)

    for epoch in range(LOCAL_EPOCHS):
        model.train()
        total_loss, correct, total, total_perplexity = 0, 0, 0, 0

        # Now all clients train for the same number of iterations
        for batch_idx, (images, labels) in enumerate(train_loader):
            images, labels = images.to(DEVICE), labels.squeeze().long().to(DEVICE)
            optimizer.zero_grad()

            with autocast():
                outputs, vq_loss, decoded, original, perplexity = model(images)
                recon_loss = F.mse_loss(decoded, original.detach())
                loss = criterion(outputs, labels) + 0.05 * vq_loss + 0.1 * recon_loss

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item() * images.size(0)
            total_perplexity += perplexity.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        train_acc = 100 * correct / total
        train_loss = total_loss / total
        avg_perplexity = total_perplexity / total

        # Validation
        model.eval()
        val_correct, val_total, val_loss = 0, 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(DEVICE), labels.squeeze().long().to(DEVICE)
                with autocast():
                    outputs, vq_loss, decoded, original, _ = model(images)
                    recon_loss = F.mse_loss(decoded, original.detach())
                    loss = criterion(outputs, labels) + 0.05 * vq_loss + 0.1 * recon_loss
                val_loss += loss.item() * images.size(0)
                _, preds = torch.max(outputs, 1)
                val_correct += (preds == labels).sum().item()
                val_total += labels.size(0)

        val_acc = 100 * val_correct / val_total
        val_loss = val_loss / val_total

        logger.info(f"Client {client_id+1} | Round {round_num+1} | Epoch {epoch+1} | "
                    f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
                    f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | "
                    f"Perplexity: {avg_perplexity:.2f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_state = deepcopy(model.state_dict())

    # Load best model state
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        # Use smart checkpoint management
        save_best_checkpoint(client_id, round_num, best_model_state, best_val_acc)

    # Track codebook usage after training
    model.eval()
    used_indices, usage_counts = track_codebook_usage(model, train_loader, threshold=0.01)

    return model, best_val_acc, (used_indices, usage_counts)

def evaluate_global_test(model, global_test_loader):
    """Evaluate a model on global test dataset"""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in global_test_loader:
            images, labels = images.to(DEVICE), labels.squeeze().long().to(DEVICE)
            with autocast():
                outputs, _, _, _, _ = model(images)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return 100 * correct / total

--- Part 3: Main Federated Learning Loop with Enhanced Logging ---

--- Main Federated Learning Loop ---

logger.info("\n" + "="*80)
logger.info("STARTING FEDERATED LEARNING")
logger.info("="*80)

Initialize models per client

client_models = [VQDecoderResNetClassifier().to(DEVICE) for _ in range(NUM_CLIENTS)]

Track performance across all rounds

best_client_accs = [0.0] * NUM_CLIENTS # Best global test accuracy for each client across all rounds
best_round_for_client = [0] * NUM_CLIENTS # Round where best accuracy was achieved
all_round_data = [] # Store all round data for consolidated tables
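
The transcript cuts off before the round loop itself, so what follows is a hedged sketch of one communication round wired together from the functions defined above; details such as how often print_consolidated_table fires are assumptions, not the author's actual loop:

```python
for round_num in range(COMMUNICATION_ROUNDS):
    client_vq_list, client_dec_list, client_cls_list, usage_list = [], [], [], []

    # Local training on each client
    for cid in range(NUM_CLIENTS):
        model, best_val_acc, usage = train_client(
            client_models[cid],
            client_datasets[cid]['train_loader'],
            client_datasets[cid]['val_loader'],
            cid, round_num)
        vq_p, dec_p, cls_p = get_shared_parameters(model)
        client_vq_list.append(vq_p)
        client_dec_list.append(dec_p)
        client_cls_list.append(cls_p)
        usage_list.append(usage)

    # Aggregate: usage-aware K-means for the codebook, FedAvg for everything else
    agg_vq, agg_stats = enhanced_kmeans_clustering_vq(client_vq_list, usage_list)
    agg_dec = federated_averaging(client_dec_list, list(client_dec_list[0].keys()))
    agg_cls = federated_averaging(client_cls_list, list(client_cls_list[0].keys()))

    # Broadcast the aggregated parameters back to every client
    for cid in range(NUM_CLIENTS):
        update_shared_parameters(client_models[cid], agg_vq, agg_dec, agg_cls)

    # Evaluate every client's model on the shared global test set
    global_test_accs = [evaluate_global_test(m, global_test_loader) for m in client_models]
    avg_acc = print_global_test_table(round_num, global_test_accs, train_percentages)
    all_round_data.append({'accs': global_test_accs, 'avg': avg_acc})

    if (round_num + 1) % 5 == 0:
        print_consolidated_table(round_num, all_round_data)
```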

How does my codebook aggregation happen?

Thought

Explaining codebook aggregation

The user's codebook aggregation involves several functions: federated_averaging for the ordinary parameters, and K-means clustering for the embeddings (a Wasserstein-barycenter option is imported via `ot` but not implemented in this excerpt). The core process concatenates the embedding matrices from all clients, applies K-means to find cluster centers, and uses those centers as the aggregated embeddings. Non-embedding parameters are simply averaged. The weighted and enhanced variants additionally factor in client data sizes and codebook usage to improve the clustering. So codebook aggregation happens through a combination of K-means and averaging, tailored to how the code is designed.


Answer

Your code aggregates the VQ “codebook” (i.e. the learned embedding table) by treating it as a clustering problem over all clients’ embeddings, rather than by simply averaging. Here’s how it works step by step:


1. Identify the embedding parameters

  • In each client’s model, any parameter whose name contains "embeddings" is treated as part of the codebook.
  • All other parameters (e.g. decoder and classifier weights) are aggregated by simple averaging via torch.mean.

2. Concatenate embeddings from all clients

  • You pull out each client’s embeddings matrix (shape [n_embeddings, embedding_dim]), convert to NumPy, and stack them into one big array of shape [num_clients × n_embeddings, embedding_dim].

3. Run K-Means clustering

  • You call KMeans(n_clusters=n_embeddings) on that concatenated array.
  • The idea is: if each client had its own slightly different codebook, clustering will find the most representative centers across all of them.

4. Use the cluster centers as the new codebook

  • The cluster_centers_ array (shape [n_embeddings, embedding_dim]) becomes your aggregated codebook.
  • You convert it back to a torch.Tensor and assign it to the global model’s codebook parameter.

5. Handle shape mismatches

  • If K-Means returns more or fewer centers than n_embeddings, you either truncate or pad with random/noise-based vectors to maintain the correct size.
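
Here is a minimal, self-contained distillation of steps 2–4 with toy sizes; in your code the codebooks come from get_shared_parameters and the parameter name is 'vq_layer.embeddings':

```python
import numpy as np
import torch
from sklearn.cluster import KMeans

# Toy stand-ins for three clients' codebooks, each (n_embeddings, embedding_dim)
n_embeddings, embedding_dim = 16, 8
client_codebooks = [np.random.randn(n_embeddings, embedding_dim) for _ in range(3)]

# Step 2: concatenate all clients' codebooks into one point cloud
point_cloud = np.concatenate(client_codebooks, axis=0)  # shape (3 * 16, 8)

# Step 3: cluster the point cloud back down to n_embeddings centers
kmeans = KMeans(n_clusters=n_embeddings, n_init=10, random_state=42).fit(point_cloud)

# Step 4: the cluster centers become the new shared codebook
new_codebook = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32)  # (16, 8)
```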

Variants you’ve also implemented

  1. Weighted K-Means (weighted_kmeans_clustering_vq):

    • Replicates each client’s embeddings in proportion to that client’s data size before clustering, so larger clients have more influence.
  2. Usage-Aware K-Means (enhanced_kmeans_clustering_vq):

    • First runs each client’s model over its data to see which codebook entries actually get used.
    • Clusters only the used embeddings (plus a fallback to unused ones or noise if needed), tracks clustering inertia, and then pads/truncates to form the final codebook.

In summary

Rather than averaging every codebook vector across clients, you treat all clients’ embeddings as a point cloud and re-derive the codebook via clustering, using the resulting cluster centers as the new shared codebook. All non-codebook parameters remain federated-averaged in the usual way. This lets your global codebook better capture the diversity of features learned by each client.

Question

What are VQ embeddings?


Answer

In the context of your code (and more broadly in models like VQ-VAE or VQ-GAN), “VQ embeddings” refers to the discrete set of prototype vectors—often called the codebook—that your continuous latent representations get quantized against. Here’s what that means:

  1. Continuous → Discrete

    • Your encoder (e.g. the ResNet-18 trunk) produces a dense feature tensor $z_e$ of shape $[B, C, H, W]$.
    • Instead of passing $z_e$ directly to the decoder, you "quantize" each spatial vector to the nearest entry in a fixed-size lookup table $\{c_1, c_2, \dots, c_K\}$. Those lookup-table vectors are your VQ embeddings (or code-vectors). [Wikipedia]
  2. The Codebook

    • You maintain a learnable matrix $\mathbf{E} \in \mathbb{R}^{K \times D}$, where $K$ is the number of embeddings (codebook size) and $D$ is the embedding dimensionality.
    • During the forward pass, each flattened feature vector $z_e(i) \in \mathbb{R}^D$ is replaced by its closest code-vector $c_j$ (see the toy example after this list).
    • This yields a quantized tensor $z_q$ that feeds into your decoder. [Milvus]
  3. Why Quantize?

    • Compression & Discreteness: By forcing features onto a discrete set, you reduce the infinite real-valued latent space to a finite one, which often makes downstream tasks (generation, compression, retrieval) simpler.
    • Stabilized Training: Discrete codes can regularize representations, sometimes improving sample quality in generative models (e.g. in DALL·E or VQ-GAN).
  4. Learning the Embeddings

    • You don't pick the codebook entries by hand; they're learned via a combination of reconstruction losses and a "commitment" or "quantization" loss (in your case, the MSE between $z_q$ and $z_e$).
    • The Gumbel-softmax trick in your GumbelSoftmaxVQ layer makes this discrete selection differentiable, so you can backpropagate through the "argmin" step.
  5. Aggregation in the Federated Setting

    • Since each client ends up with its own codebook matrix $\mathbf{E}^{(i)}$, you need a way to merge them. Instead of simply averaging, your K-means aggregation treats all clients' embeddings as points in $\mathbb{R}^D$, re-clusters them into $K$ centers, and uses those centers as the new shared codebook.
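
To make the "snapping" concrete, here is a toy nearest-neighbour quantisation in the classic VQ-VAE style; your layer replaces this hard argmin with a Gumbel-softmax sample, but the geometry is the same:

```python
import torch

codebook = torch.tensor([[0.0, 0.0],
                         [1.0, 1.0],
                         [-1.0, 2.0]])  # K = 3 code-vectors, D = 2

z_e = torch.tensor([[0.9, 1.2]])       # one continuous feature vector

dists = torch.cdist(z_e, codebook)     # distance to each code-vector
j = dists.argmin()                     # nearest entry -> index 1
z_q = codebook[j]                      # quantized vector: tensor([1., 1.])
print(j.item(), z_q)
```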

In short:

  • VQ embeddings = the rows of your model’s learnable codebook matrix.
  • They serve as the finite set of prototypes that continuous features are snapped to during quantization.
  • Aggregating them via clustering (rather than averaging) lets you discover a common vocabulary of latent codes across clients.
