# --- Part 1: Imports, Constants, Dataset Loading,...
Created: July 14, 2025
Answered using O4 Mini High from Chat01
from sklearn.cluster import KMeans
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, Dataset
import numpy as np
from collections import defaultdict, Counter
from copy import deepcopy
import logging
from torchvision import transforms
from torchvision.models import resnet18, ResNet18_Weights
from medmnist import INFO, PathMNIST
from medmnist import __dict__ as medmnist_dict
from torch.cuda.amp import GradScaler, autocast
import ot # For Wasserstein barycenter
import json
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Rectangle
logging.basicConfig(
level=logging.INFO,
format='%(levelname)s : %(message)s',
handlers=[logging.FileHandler('14th-july-Enhanced-kmeans-perplexity.log'), logging.StreamHandler()]
)
logger = logging.getLogger(__name__)
def get_class_distribution(dataset, indices):
"""Get class distribution for given indices"""
labels = [dataset[i][1] for i in indices] # no .item()
return Counter(labels)
BATCH_SIZE = 64
NUM_CLASSES = 9
NUM_CLIENTS = 10
LOCAL_EPOCHS = 5
COMMUNICATION_ROUNDS = 100
ALPHA = 0.5
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {DEVICE}")
logger.info("Using ImageNet pre-trained ResNet18 weights for feature extraction")
strong_transform = transforms.Compose([
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomVerticalFlip(p=0.3),
transforms.RandomRotation(degrees=15),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
transforms.ToTensor(),
transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])
standard_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(10),
transforms.ToTensor(),
transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
])
info = INFO['pathmnist']
DataClass = medmnist_dict[info['python_class']]
from PIL import Image
train_ds = DataClass(split='train', download=True)
val_ds = DataClass(split='val', download=True)
combined_imgs = np.concatenate([train_ds.imgs, val_ds.imgs], axis=0)
combined_labels = np.concatenate([train_ds.labels, val_ds.labels], axis=0).squeeze()
class CombinedDataset(torch.utils.data.Dataset):
    def __init__(self, imgs, labels, transform=None):
        self.imgs = imgs
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img = Image.fromarray(self.imgs[idx])
        label = int(self.labels[idx])
        if self.transform:
            img = self.transform(img)
        return img, label
full_combined_dataset = CombinedDataset(combined_imgs, combined_labels, transform=None)
full_train_dataset=full_combined_dataset
full_val_dataset=full_combined_dataset
full_test_dataset = DataClass(split='test', transform=transform_test, download=True)
class AugmentedDataset(Dataset):
    def __init__(self, base_dataset, indices, transform, target_iterations, batch_size):
        self.base_dataset = base_dataset
        self.indices = indices
        self.transform = transform
        self.target_size = target_iterations * batch_size
        self.actual_size = len(indices)

        # Calculate how many times we need to repeat the dataset
        self.repeat_factor = max(1, self.target_size // self.actual_size)
        if self.target_size % self.actual_size > 0:
            self.repeat_factor += 1

        logger.info(f"Dataset with {self.actual_size} samples expanded to {self.target_size} samples (repeat factor: {self.repeat_factor})")

    def __len__(self):
        return self.target_size

    def __getitem__(self, idx):
        # Map the index to the original dataset
        original_idx = self.indices[idx % self.actual_size]
        image, label = self.base_dataset[original_idx]
        # Apply transform
        if self.transform:
            image = self.transform(image)
        return image, label
def dirichlet_partition(dataset, alpha, num_clients, min_samples_per_class=20):
    labels = np.array([int(dataset[i][1]) for i in range(len(dataset))])
    num_classes = np.max(labels) + 1
    class_indices = [np.where(labels == y)[0] for y in range(num_classes)]
    client_indices = [[] for _ in range(num_clients)]

    for c, indices in enumerate(class_indices):
        np.random.shuffle(indices)

        # Check if we have enough samples for minimum allocation
        required_samples = min_samples_per_class * num_clients
        if len(indices) < required_samples:
            logger.warning(f"Class {c} has only {len(indices)} samples, but needs {required_samples}. Distributing available samples equally.")
            # Distribute available samples as equally as possible
            samples_per_client = len(indices) // num_clients
            remainder = len(indices) % num_clients
            start_idx = 0
            for client_id in range(num_clients):
                end_idx = start_idx + samples_per_client + (1 if client_id < remainder else 0)
                client_indices[client_id].extend(indices[start_idx:end_idx].tolist())
                start_idx = end_idx
            continue

        # Step 1: First distribute minimum samples to each client
        remaining_indices = list(indices)
        for client_id in range(num_clients):
            allocated_samples = remaining_indices[:min_samples_per_class]
            client_indices[client_id].extend(allocated_samples)
            remaining_indices = remaining_indices[min_samples_per_class:]

        # Step 2: Apply Dirichlet distribution to remaining samples
        if len(remaining_indices) > 0:
            proportions = np.random.dirichlet([alpha] * num_clients)
            # Convert proportions to actual sample counts
            sample_counts = (proportions * len(remaining_indices)).astype(int)
            # Handle rounding errors - distribute remaining samples
            remaining_to_distribute = len(remaining_indices) - sample_counts.sum()
            if remaining_to_distribute > 0:
                # Add remaining samples to clients with highest proportions
                top_clients = np.argsort(proportions)[-remaining_to_distribute:]
                for client_id in top_clients:
                    sample_counts[client_id] += 1
            # Distribute the remaining samples according to Dirichlet proportions
            start_idx = 0
            for client_id in range(num_clients):
                end_idx = start_idx + sample_counts[client_id]
                if end_idx > start_idx:  # Only add if there are samples to add
                    client_indices[client_id].extend(remaining_indices[start_idx:end_idx])
                start_idx = end_idx

    return client_indices
global_test_dataset = full_test_dataset
global_test_loader = DataLoader(global_test_dataset, batch_size=BATCH_SIZE, shuffle=False)
#train_client_indices = dirichlet_partition(full_train_dataset, alpha=ALPHA, num_clients=NUM_CLIENTS, min_samples_per_class=20)
#val_client_indices = dirichlet_partition(full_val_dataset, alpha=ALPHA, num_clients=NUM_CLIENTS, min_samples_per_class=5)
'''logger.info("Loading client indices from JSON files...")
with open('train_client_indices.json', 'r') as f:
train_client_indices = json.load(f)
with open('val_client_indices.json', 'r') as f:
val_client_indices = json.load(f)'''
logger.info("Loading client indices from JSON files...")
with open('0.5_dirichlet_pathmnist/train_client_indices.json', 'r') as f:
    train_client_indices_raw = json.load(f)
with open('0.5_dirichlet_pathmnist/val_client_indices.json', 'r') as f:
    val_client_indices_raw = json.load(f)
train_client_indices = []
val_client_indices = []
if isinstance(train_client_indices_raw, dict):
    for client_id in range(NUM_CLIENTS):
        client_key = str(client_id)  # JSON keys are strings
        if client_key in train_client_indices_raw:
            train_client_indices.append(train_client_indices_raw[client_key])
        else:
            logger.error(f"Client {client_id} not found in train_client_indices.json")
            raise KeyError(f"Client {client_id} not found in JSON")
else:
    # If JSON is already a list of lists
    train_client_indices = train_client_indices_raw

if isinstance(val_client_indices_raw, dict):
    for client_id in range(NUM_CLIENTS):
        client_key = str(client_id)
        if client_key in val_client_indices_raw:
            val_client_indices.append(val_client_indices_raw[client_key])
        else:
            logger.error(f"Client {client_id} not found in val_client_indices.json")
            raise KeyError(f"Client {client_id} not found in JSON")
else:
    val_client_indices = val_client_indices_raw
logger.info(f"Loaded train indices for {len(train_client_indices)} clients")
logger.info(f"Loaded val indices for {len(val_client_indices)} clients")
for client_id in range(NUM_CLIENTS):
    logger.info(f"Client {client_id + 1}: Train samples = {len(train_client_indices[client_id])}, Val samples = {len(val_client_indices[client_id])}")
logger.info(f"Loaded train indices for {len(train_client_indices)} clients")
logger.info(f"Loaded val indices for {len(val_client_indices)} clients")
max_train_size = max(len(indices) for indices in train_client_indices)
max_val_size = max(len(indices) for indices in val_client_indices)
max_train_iterations = max_train_size // BATCH_SIZE
max_val_iterations = max_val_size // BATCH_SIZE
logger.info(f"Maximum training dataset size: {max_train_size}")
logger.info(f"Maximum validation dataset size: {max_val_size}")
logger.info(f"Target iterations per epoch: {max_train_iterations} (train), {max_val_iterations} (val)")
client_datasets = []
for i in range(NUM_CLIENTS):
    train_size = len(train_client_indices[i])
    val_size = len(val_client_indices[i])

    # Determine if client needs strong augmentation (smaller datasets)
    if train_size < max_train_size * 0.7:  # If less than 70% of max size
        train_transform = strong_transform
        logger.info(f"Client {i+1}: Using strong augmentation (dataset size: {train_size})")
    else:
        train_transform = standard_transform
        logger.info(f"Client {i+1}: Using standard augmentation (dataset size: {train_size})")

    # Create augmented datasets
    train_dataset = AugmentedDataset(
        full_train_dataset, train_client_indices[i], train_transform,
        max_train_iterations, BATCH_SIZE
    )
    val_dataset = AugmentedDataset(
        full_val_dataset, val_client_indices[i], transform_test,
        max_val_iterations, BATCH_SIZE
    )

    client_datasets.append({
        'train': train_dataset,
        'val': val_dataset,
        'train_loader': DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True),
        'val_loader': DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)
    })
def calculate_client_data_percentages(train_indices, val_indices):
"""Calculate percentage of total data each client has"""
texttotal_train_samples = sum(len(indices) for indices in train_indices) total_val_samples = sum(len(indices) for indices in val_indices) train_percentages = [] val_percentages = [] for client_id in range(NUM_CLIENTS): train_pct = (len(train_indices[client_id]) / total_train_samples) * 100 val_pct = (len(val_indices[client_id]) / total_val_samples) * 100 train_percentages.append(train_pct) val_percentages.append(val_pct) return train_percentages, val_percentages
#train_percentages, val_percentages = calculate_client_data_percentages(train_client_indices, val_client_indices)
logger.info("\n" + "=" * 80)
logger.info("INDIVIDUAL CLIENT DISTRIBUTIONS")
logger.info("=" * 80)
for client_id in range(NUM_CLIENTS):
    train_dist = get_class_distribution(full_train_dataset, train_client_indices[client_id])
    val_dist = get_class_distribution(full_val_dataset, val_client_indices[client_id])
    logger.info(f"Client {client_id + 1} Train Dist: {dict(train_dist)} (Original: {len(train_client_indices[client_id])}, Augmented: {len(client_datasets[client_id]['train'])})")
    logger.info(f"Client {client_id + 1} Val Dist: {dict(val_dist)} (Original: {len(val_client_indices[client_id])}, Augmented: {len(client_datasets[client_id]['val'])})")
train_percentages, val_percentages = calculate_client_data_percentages(train_client_indices, val_client_indices)
logger.info("\n" + "=" * 80)
logger.info("CLIENT DATA PERCENTAGES")
logger.info("=" * 80)
for client_id in range(NUM_CLIENTS):
logger.info(f"Client {client_id + 1}: Train {train_percentages[client_id]:.1f}%, Val {val_percentages[client_id]:.1f}%")
global_test_labels = [full_test_dataset[i][1].item() for i in range(len(full_test_dataset))]
global_test_dist = Counter(global_test_labels)
logger.info(f"\nGlobal Test Distribution (Total: {len(full_test_dataset)}): {dict(global_test_dist)}")
client_best_val_accs = [0.0] * NUM_CLIENTS # Track best validation accuracy for each client
client_best_checkpoints = [None] * NUM_CLIENTS # Track best checkpoint paths for each client
class GumbelSoftmaxVQ(nn.Module):
    def __init__(self, num_embeddings=512, embedding_dim=512, temp_init=1.0):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.embeddings = nn.Parameter(torch.randn(num_embeddings, embedding_dim))
        self.proj = nn.Conv2d(embedding_dim, num_embeddings, 1)
        self.temperature = temp_init

    def forward(self, x):
        B, C, H, W = x.shape
        logits = self.proj(x).permute(0, 2, 3, 1).reshape(-1, self.num_embeddings)
        soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, hard=True)
        z_q = soft_one_hot @ self.embeddings
        z_q = z_q.view(B, H, W, C).permute(0, 3, 1, 2)
        avg_probs = soft_one_hot.mean(dim=0)
        entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-10))
        perplexity = torch.exp(entropy)
        commitment_loss = F.mse_loss(z_q.detach(), x)
        return z_q, commitment_loss, perplexity
class VQDecoderResNetClassifier(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES, embedding_dim=512):
        super().__init__()
        # Use ImageNet pre-trained weights
        self.resnet = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        self.encoder = nn.Sequential(*list(self.resnet.children())[:-2])
        self.vq_layer = GumbelSoftmaxVQ(num_embeddings=512, embedding_dim=embedding_dim)

        self.decoder_conv1 = nn.Sequential(
            nn.Conv2d(embedding_dim, embedding_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(embedding_dim),
            nn.ReLU(inplace=True)
        )
        self.decoder_conv2 = nn.Sequential(
            nn.Conv2d(embedding_dim, embedding_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(embedding_dim),
            nn.ReLU(inplace=True)
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        features = self.encoder(x)
        quantized, vq_loss, perplexity = self.vq_layer(features)
        x1 = self.decoder_conv1(quantized)
        x2 = self.decoder_conv2(x1 + quantized)
        decoded = x2 + x1
        pooled = self.avgpool(decoded)
        pooled = torch.flatten(pooled, 1)
        output = self.classifier(pooled)
        return output, vq_loss, decoded, features, perplexity
def federated_averaging(client_params_list, param_names):
    aggregated_params = {}
    for param_name in param_names:
        stacked_params = torch.stack([params[param_name] for params in client_params_list])
        aggregated_params[param_name] = torch.mean(stacked_params, dim=0)
    return aggregated_params
def kmeans_clustering_vq(client_vq_params_list, n_clusters=None, random_state=42):
"""
Aggregate VQ embeddings using K-means clustering
textArgs: client_vq_params_list: List of VQ parameter dictionaries from clients n_clusters: Number of clusters (if None, uses the embedding dimension) random_state: Random state for reproducibility Returns: aggregated_vq: Aggregated VQ parameters """ aggregated_vq = {} for param_name in client_vq_params_list[0].keys(): if 'embeddings' in param_name: # Extract embeddings from all clients embeddings_list = [params[param_name].cpu().numpy() for params in client_vq_params_list] # Get dimensions n_clients = len(embeddings_list) n_embeddings, embedding_dim = embeddings_list[0].shape # Set number of clusters if not specified if n_clusters is None: n_clusters = n_embeddings logger.info(f"Aggregating {param_name} using K-means with {n_clusters} clusters") # Concatenate all embeddings from all clients all_embeddings = np.concatenate(embeddings_list, axis=0) # Shape: (n_clients * n_embeddings, embedding_dim) # Apply K-means clustering kmeans = KMeans(n_clusters=n_clusters, random_state=random_state, n_init=10) kmeans.fit(all_embeddings) # Use cluster centers as the aggregated embeddings aggregated_embeddings = kmeans.cluster_centers_ # Ensure we have the correct number of embeddings if aggregated_embeddings.shape[0] != n_embeddings: logger.warning(f"K-means produced {aggregated_embeddings.shape[0]} clusters, but need {n_embeddings} embeddings") if aggregated_embeddings.shape[0] > n_embeddings: # Take the first n_embeddings clusters aggregated_embeddings = aggregated_embeddings[:n_embeddings] else: # Pad with random embeddings if we have fewer clusters remaining = n_embeddings - aggregated_embeddings.shape[0] padding = np.random.randn(remaining, embedding_dim) * 0.1 aggregated_embeddings = np.concatenate([aggregated_embeddings, padding], axis=0) aggregated_vq[param_name] = torch.tensor(aggregated_embeddings, dtype=torch.float32) logger.info(f"Aggregated {param_name}: {aggregated_embeddings.shape}") else: # For non-embedding parameters, use standard averaging stacked = torch.stack([params[param_name] for params in client_vq_params_list]) aggregated_vq[param_name] = torch.mean(stacked, dim=0) return aggregated_vq
def weighted_kmeans_clustering_vq(client_vq_params_list, client_data_sizes, n_clusters=None, random_state=42):
"""
Aggregate VQ embeddings using weighted K-means clustering based on client data sizes
textArgs: client_vq_params_list: List of VQ parameter dictionaries from clients client_data_sizes: List of data sizes for each client (for weighting) n_clusters: Number of clusters (if None, uses the embedding dimension) random_state: Random state for reproducibility Returns: aggregated_vq: Aggregated VQ parameters """ aggregated_vq = {} for param_name in client_vq_params_list[0].keys(): if 'embeddings' in param_name: # Extract embeddings from all clients embeddings_list = [params[param_name].cpu().numpy() for params in client_vq_params_list] # Get dimensions n_clients = len(embeddings_list) n_embeddings, embedding_dim = embeddings_list[0].shape # Set number of clusters if not specified if n_clusters is None: n_clusters = n_embeddings logger.info(f"Aggregating {param_name} using weighted K-means with {n_clusters} clusters") # Create weighted samples based on client data sizes total_data_size = sum(client_data_sizes) client_weights = [size / total_data_size for size in client_data_sizes] # Create weighted embeddings weighted_embeddings = [] for i, embeddings in enumerate(embeddings_list): # Replicate embeddings based on client weight weight = client_weights[i] n_replicas = max(1, int(weight * n_clients * n_embeddings)) # Sample embeddings with replacement indices = np.random.choice(n_embeddings, size=n_replicas, replace=True) weighted_embeddings.append(embeddings[indices]) # Concatenate all weighted embeddings all_weighted_embeddings = np.concatenate(weighted_embeddings, axis=0) # Apply K-means clustering kmeans = KMeans(n_clusters=n_clusters, random_state=random_state, n_init=10) kmeans.fit(all_weighted_embeddings) # Use cluster centers as the aggregated embeddings aggregated_embeddings = kmeans.cluster_centers_ # Ensure we have the correct number of embeddings if aggregated_embeddings.shape[0] != n_embeddings: logger.warning(f"Weighted K-means produced {aggregated_embeddings.shape[0]} clusters, but need {n_embeddings} embeddings") if aggregated_embeddings.shape[0] > n_embeddings: # Take the first n_embeddings clusters aggregated_embeddings = aggregated_embeddings[:n_embeddings] else: # Pad with random embeddings if we have fewer clusters remaining = n_embeddings - aggregated_embeddings.shape[0] padding = np.random.randn(remaining, embedding_dim) * 0.1 aggregated_embeddings = np.concatenate([aggregated_embeddings, padding], axis=0) aggregated_vq[param_name] = torch.tensor(aggregated_embeddings, dtype=torch.float32) logger.info(f"Aggregated {param_name}: {aggregated_embeddings.shape}") else: # For non-embedding parameters, use standard averaging stacked = torch.stack([params[param_name] for params in client_vq_params_list]) aggregated_vq[param_name] = torch.mean(stacked, dim=0) return aggregated_vq
def enhanced_kmeans_clustering_vq(client_vq_params_list, client_usage_data,
                                  n_clusters=None, random_state=42, n_init=20):
    """
    Enhanced K-means clustering with codebook usage tracking.

    Args:
        client_vq_params_list: List of VQ parameter dictionaries from clients
        client_usage_data: List of (used_indices, usage_counts) tuples for each client
        n_clusters: Number of clusters (if None, uses the codebook size)
        random_state: Random state for reproducibility
        n_init: Number of random initializations

    Returns:
        aggregated_vq: Aggregated VQ parameters
        aggregation_stats: Statistics about the aggregation process
    """
    aggregated_vq = {}
    aggregation_stats = {}

    for param_name in client_vq_params_list[0].keys():
        if 'embeddings' in param_name:
            # Extract embeddings from all clients
            embeddings_list = [params[param_name].cpu().numpy() for params in client_vq_params_list]

            # Get dimensions
            n_clients = len(embeddings_list)
            n_embeddings, embedding_dim = embeddings_list[0].shape

            # Set number of clusters if not specified
            if n_clusters is None:
                n_clusters = n_embeddings

            logger.info(f"Aggregating {param_name} using usage-aware enhanced K-means")

            # Collect only used embeddings from all clients
            used_embeddings = []
            used_indices_global = set()
            client_contributions = []

            for client_id, (embeddings, (used_indices, usage_counts)) in enumerate(zip(embeddings_list, client_usage_data)):
                used_indices_global.update(used_indices)
                # Extract only used embeddings from this client
                if used_indices:
                    client_used_embeddings = embeddings[list(used_indices)]
                    used_embeddings.append(client_used_embeddings)
                    client_contributions.append(len(used_indices))
                    logger.info(f"Client {client_id+1}: Using {len(used_indices)} out of {n_embeddings} embeddings")
                else:
                    logger.warning(f"Client {client_id+1}: No used embeddings found, using all embeddings")
                    used_embeddings.append(embeddings)
                    client_contributions.append(n_embeddings)

            # Combine all used embeddings
            if used_embeddings:
                all_used_embeddings = np.concatenate(used_embeddings, axis=0)
                logger.info(f"Total used embeddings for clustering: {all_used_embeddings.shape[0]}")
            else:
                # Fallback: use all embeddings if no usage data
                all_used_embeddings = np.concatenate(embeddings_list, axis=0)
                logger.warning("No usage data available, using all embeddings")

            # Apply K-means clustering with enhanced parameters
            actual_clusters = min(n_clusters, all_used_embeddings.shape[0])
            kmeans = KMeans(
                n_clusters=actual_clusters,
                random_state=random_state,
                n_init=n_init,
                init='k-means++',
                max_iter=300,
                tol=1e-4
            )

            try:
                kmeans.fit(all_used_embeddings)
                cluster_centers = kmeans.cluster_centers_
                inertia = kmeans.inertia_
                logger.info(f"K-means completed: {actual_clusters} clusters, inertia: {inertia:.4f}")
            except Exception as e:
                logger.error(f"K-means clustering failed: {e}")
                # Fallback to simple averaging
                cluster_centers = np.mean(embeddings_list, axis=0)
                inertia = 0.0

            # Create final embedding matrix
            if cluster_centers.shape[0] != n_embeddings:
                if cluster_centers.shape[0] > n_embeddings:
                    # Take the first n_embeddings clusters
                    final_embeddings = cluster_centers[:n_embeddings]
                else:
                    # Pad with unused embeddings from clients or random initialization
                    final_embeddings = np.zeros((n_embeddings, embedding_dim))
                    final_embeddings[:cluster_centers.shape[0]] = cluster_centers

                    # Fill remaining positions with averaged unused embeddings
                    remaining = n_embeddings - cluster_centers.shape[0]
                    if remaining > 0:
                        # Get unused embeddings from clients
                        unused_embeddings = []
                        for client_id, (embeddings, (used_indices, _)) in enumerate(zip(embeddings_list, client_usage_data)):
                            all_indices = set(range(n_embeddings))
                            unused_indices = all_indices - used_indices
                            if unused_indices:
                                unused_embeddings.append(embeddings[list(unused_indices)])

                        if unused_embeddings:
                            all_unused = np.concatenate(unused_embeddings, axis=0)
                            if all_unused.shape[0] >= remaining:
                                # Use K-means on unused embeddings
                                unused_kmeans = KMeans(n_clusters=remaining, random_state=random_state, n_init=5)
                                unused_centers = unused_kmeans.fit(all_unused).cluster_centers_
                                final_embeddings[cluster_centers.shape[0]:] = unused_centers
                            else:
                                # Use all unused embeddings and pad with noise
                                final_embeddings[cluster_centers.shape[0]:cluster_centers.shape[0]+all_unused.shape[0]] = all_unused
                                # Fill remaining with small random noise around used embeddings
                                if cluster_centers.shape[0] > 0:
                                    noise_base = np.mean(cluster_centers, axis=0)
                                    noise = np.random.randn(remaining - all_unused.shape[0], embedding_dim) * 0.01
                                    final_embeddings[cluster_centers.shape[0]+all_unused.shape[0]:] = noise_base + noise
                        else:
                            # No unused embeddings available, use random initialization
                            final_embeddings[cluster_centers.shape[0]:] = np.random.randn(remaining, embedding_dim) * 0.1
            else:
                final_embeddings = cluster_centers

            aggregated_vq[param_name] = torch.tensor(final_embeddings, dtype=torch.float32)

            # Store aggregation statistics
            aggregation_stats[param_name] = {
                'used_embeddings_count': len(used_indices_global),
                'total_embeddings': n_embeddings,
                'usage_ratio': len(used_indices_global) / n_embeddings,
                'client_contributions': client_contributions,
                'clustering_inertia': inertia,
                'final_shape': final_embeddings.shape
            }

            logger.info(f"Aggregated {param_name}: {final_embeddings.shape}, "
                        f"Used {len(used_indices_global)}/{n_embeddings} embeddings "
                        f"({100*len(used_indices_global)/n_embeddings:.1f}%)")
        else:
            # For non-embedding parameters, use standard averaging
            stacked = torch.stack([params[param_name] for params in client_vq_params_list])
            aggregated_vq[param_name] = torch.mean(stacked, dim=0)

    return aggregated_vq, aggregation_stats
def get_shared_parameters(model):
    vq_params, decoder_params, classifier_params = {}, {}, {}
    for name, param in model.named_parameters():
        if 'vq_layer' in name:
            vq_params[name] = param.data.clone()
        elif 'decoder' in name:
            decoder_params[name] = param.data.clone()
        elif 'classifier' in name or 'avgpool' in name:
            classifier_params[name] = param.data.clone()
    return vq_params, decoder_params, classifier_params
def track_codebook_usage(model, data_loader, threshold=0.01):
"""
Track which codebook entries are being used by a model
textArgs: model: The VQ model data_loader: DataLoader to run inference on threshold: Minimum usage frequency to consider an embedding as "used" Returns: used_indices: Set of indices of used codebook entries usage_counts: Dictionary mapping codebook index to usage count """ model.eval() codebook_size = model.vq_layer.num_embeddings usage_counts = np.zeros(codebook_size) total_tokens = 0 with torch.no_grad(): for images, _ in data_loader: images = images.to(DEVICE) # Get VQ layer output features = model.encoder(images) B, C, H, W = features.shape # Get logits and softmax probabilities logits = model.vq_layer.proj(features).permute(0, 2, 3, 1).reshape(-1, codebook_size) probs = F.softmax(logits, dim=-1) # Count usage based on max probability (which embedding is most likely chosen) max_indices = torch.argmax(probs, dim=-1) # Count occurrences for idx in max_indices.cpu().numpy(): usage_counts[idx] += 1 total_tokens += len(max_indices) # Calculate usage frequencies usage_frequencies = usage_counts / total_tokens # Find used indices based on threshold used_indices = set(np.where(usage_frequencies >= threshold)[0]) logger.info(f"Codebook usage: {len(used_indices)}/{codebook_size} embeddings used " f"(threshold: {threshold}, max_freq: {usage_frequencies.max():.4f})") return used_indices, dict(enumerate(usage_counts))
def update_shared_parameters(model, vq_params, decoder_params, classifier_params):
    state_dict = model.state_dict()
    for name, param in vq_params.items():
        if name in state_dict:
            state_dict[name].copy_(param.to(DEVICE))
    for name, param in decoder_params.items():
        if name in state_dict:
            state_dict[name].copy_(param.to(DEVICE))
    for name, param in classifier_params.items():
        if name in state_dict:
            state_dict[name].copy_(param.to(DEVICE))
def save_best_checkpoint(client_id, round_num, model_state, val_acc):
"""Save checkpoint only if it's better than previous best"""
global client_best_val_accs, client_best_checkpoints
textos.makedirs("checkpoints", exist_ok=True) if val_acc > client_best_val_accs[client_id]: # Delete old checkpoint if it exists if client_best_checkpoints[client_id] is not None: old_checkpoint = client_best_checkpoints[client_id] if os.path.exists(old_checkpoint): os.remove(old_checkpoint) logger.info(f"Removed old checkpoint: {old_checkpoint}") # Save new best checkpoint new_checkpoint = f"14th-july-Enhanced-kmeans-perplexity_client_{client_id+1}_best.pth" torch.save(model_state, new_checkpoint) # Update tracking client_best_val_accs[client_id] = val_acc client_best_checkpoints[client_id] = new_checkpoint logger.info(f"Client {client_id+1}: New best checkpoint saved with val_acc: {val_acc:.2f}%") return True else: logger.info(f"Client {client_id+1}: Val acc {val_acc:.2f}% not better than best {client_best_val_accs[client_id]:.2f}%") return False
def print_global_test_table(round_num, global_test_accs, train_percentages):
"""Print global test accuracies in tabular format with client percentages"""
logger.info(f"\n{'='*90}")
logger.info(f"GLOBAL TEST ACCURACIES - ROUND {round_num + 1}")
logger.info(f"{'='*90}")
text# Header with percentages header = "| Client |" for i in range(NUM_CLIENTS): header += f" {i+1:2d}({train_percentages[i]:4.1f}%) |" header += " Average |" logger.info(header) logger.info("|" + "-" * 8 + "|" + "-" * 11 * NUM_CLIENTS + "|" + "-" * 9 + "|") # Accuracy row acc_row = "| Acc(%) |" for acc in global_test_accs: acc_row += f" {acc:5.2f} |" avg_acc = np.mean(global_test_accs) acc_row += f" {avg_acc:6.2f} |" logger.info(acc_row) logger.info(f"{'='*90}") return avg_acc
def print_consolidated_table(round_num, all_round_data):
"""Print consolidated table every 5 rounds"""
logger.info(f"\n{'='*100}")
logger.info(f"CONSOLIDATED GLOBAL TEST ACCURACIES - ROUNDS 1 to {round_num + 1}")
logger.info(f"{'='*100}")
text# Header header = "| Round |" + " ".join([f" {i+1:2d} |" for i in range(NUM_CLIENTS)]) + " Average |" logger.info(header) logger.info("|" + "-" * 7 + "|" + "-" * 6 * NUM_CLIENTS + "|" + "-" * 9 + "|") # Show last 5 rounds or all if less than 5 start_round = max(0, round_num - 4) for r in range(start_round, round_num + 1): if r < len(all_round_data): round_data = all_round_data[r] acc_row = f"| {r+1:5d} |" for acc in round_data['accs']: acc_row += f"{acc:5.2f} |" acc_row += f" {round_data['avg']:6.2f} |" logger.info(acc_row) logger.info(f"{'='*100}") # Show best performance in this window if all_round_data: window_data = all_round_data[start_round:round_num+1] best_avg = max(data['avg'] for data in window_data) best_round = next(r for r, data in enumerate(window_data, start_round) if data['avg'] == best_avg) logger.info(f"Best Average in Window: {best_avg:.2f}% (Round {best_round + 1})") logger.info(f"{'='*100}")
def evaluate_class_wise_accuracy(model, test_loader, num_classes=NUM_CLASSES):
"""Evaluate class-wise accuracy for a model"""
model.eval()
class_correct = [0] * num_classes
class_total = [0] * num_classes
textwith torch.no_grad(): for images, labels in test_loader: images, labels = images.to(DEVICE), labels.squeeze().long().to(DEVICE) with autocast(): outputs, _, _, _, _ = model(images) _, preds = torch.max(outputs, 1) # Calculate class-wise accuracy for i in range(labels.size(0)): label = labels[i].item() class_total[label] += 1 if preds[i] == label: class_correct[label] += 1 # Calculate accuracy for each class class_accuracies = [] for i in range(num_classes): if class_total[i] > 0: accuracy = 100 * class_correct[i] / class_total[i] class_accuracies.append(accuracy) else: class_accuracies.append(0.0) # No samples for this class return class_accuracies, class_total
def print_class_wise_accuracy_table(round_num, client_models, global_test_loader, train_percentages):
"""Print class-wise accuracy table for all clients"""
logger.info(f"\n{'='*120}")
logger.info(f"CLASS-WISE GLOBAL TEST ACCURACIES - ROUND {round_num + 1}")
logger.info(f"{'='*120}")
text# Evaluate class-wise accuracy for each client all_client_class_accs = [] all_client_class_totals = [] for client_id in range(NUM_CLIENTS): class_accs, class_totals = evaluate_class_wise_accuracy(client_models[client_id], global_test_loader) all_client_class_accs.append(class_accs) all_client_class_totals.append(class_totals) # Create header with client percentages header = "| Class |" for i in range(NUM_CLIENTS): header += f" {i+1:2d}({train_percentages[i]:4.1f}%) |" header += " Average |" logger.info(header) logger.info("|" + "-" * 7 + "|" + "-" * 11 * NUM_CLIENTS + "|" + "-" * 9 + "|") # Print accuracy for each class for class_id in range(NUM_CLASSES): class_row = f"| {class_id:5d} |" class_accs_for_avg = [] for client_id in range(NUM_CLIENTS): acc = all_client_class_accs[client_id][class_id] class_row += f" {acc:5.2f} |" class_accs_for_avg.append(acc) avg_acc = np.mean(class_accs_for_avg) class_row += f" {avg_acc:6.2f} |" logger.info(class_row) logger.info(f"{'='*120}") # Calculate and print overall average overall_client_avgs = [] for client_id in range(NUM_CLIENTS): client_avg = np.mean(all_client_class_accs[client_id]) overall_client_avgs.append(client_avg) overall_avg = np.mean(overall_client_avgs) # Print overall average row logger.info("OVERALL AVERAGE:") avg_row = "| Avg |" for client_avg in overall_client_avgs: avg_row += f" {client_avg:5.2f} |" avg_row += f" {overall_avg:6.2f} |" logger.info(avg_row) logger.info(f"{'='*120}") return overall_client_avgs, overall_avg
def plot_client_data_distribution(train_indices, val_indices, full_train_dataset, full_val_dataset):
"""Plot histogram showing class distribution for each client"""
text# Get class distributions for all clients train_distributions = [] val_distributions = [] for client_id in range(NUM_CLIENTS): train_dist = get_class_distribution(full_train_dataset, train_indices[client_id]) val_dist = get_class_distribution(full_val_dataset, val_indices[client_id]) # Convert to arrays for plotting train_counts = [train_dist.get(i, 0) for i in range(NUM_CLASSES)] val_counts = [val_dist.get(i, 0) for i in range(NUM_CLASSES)] train_distributions.append(train_counts) val_distributions.append(val_counts) # Create subplots fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8)) # Plot training data distribution x = np.arange(NUM_CLIENTS) width = 0.8 bottom_train = np.zeros(NUM_CLIENTS) colors = plt.cm.tab10(np.linspace(0, 1, NUM_CLASSES)) for class_id in range(NUM_CLASSES): class_counts = [train_distributions[client][class_id] for client in range(NUM_CLIENTS)] ax1.bar(x, class_counts, width, bottom=bottom_train, label=f'Class {class_id}', color=colors[class_id]) bottom_train += class_counts ax1.set_xlabel('Client ID') ax1.set_ylabel('Number of Samples') ax1.set_title('Training Data Distribution per Client') ax1.set_xticks(x) ax1.set_xticklabels([f'Client {i+1}' for i in range(NUM_CLIENTS)]) ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left') # Plot validation data distribution bottom_val = np.zeros(NUM_CLIENTS) for class_id in range(NUM_CLASSES): class_counts = [val_distributions[client][class_id] for client in range(NUM_CLIENTS)] ax2.bar(x, class_counts, width, bottom=bottom_val, label=f'Class {class_id}', color=colors[class_id]) bottom_val += class_counts ax2.set_xlabel('Client ID') ax2.set_ylabel('Number of Samples') ax2.set_title('Validation Data Distribution per Client') ax2.set_xticks(x) ax2.set_xticklabels([f'Client {i+1}' for i in range(NUM_CLIENTS)]) ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left') plt.tight_layout() plt.savefig('client_data_distribution.png', dpi=300, bbox_inches='tight') plt.show() logger.info("Data distribution histogram saved as 'client_data_distribution.png'")
def print_enhanced_client_distribution_table(train_indices, val_indices, full_train_dataset, full_val_dataset):
"""Print enhanced distribution table with percentages"""
text# Calculate percentages train_percentages, val_percentages = calculate_client_data_percentages(train_indices, val_indices) logger.info("=" * 120) logger.info("ENHANCED CLIENT-WISE DATA DISTRIBUTION WITH PERCENTAGES") logger.info("=" * 120) # Create header with percentages header = "| Client |" for i in range(NUM_CLIENTS): header += f" {i+1:2d}({train_percentages[i]:4.1f}%) |" header += " Total |" logger.info(header) logger.info("|" + "-" * 8 + "|" + "-" * 11 * NUM_CLIENTS + "|" + "-" * 7 + "|") # Print training data distribution logger.info("TRAINING DATA:") for class_id in range(NUM_CLASSES): class_row = f"| Cls {class_id:2d} |" class_total = 0 for client_id in range(NUM_CLIENTS): train_dist = get_class_distribution(full_train_dataset, train_indices[client_id]) count = train_dist.get(class_id, 0) class_row += f" {count:4d} |" class_total += count class_row += f" {class_total:4d} |" logger.info(class_row) # Print total training samples per client total_row = "| Total |" grand_total = 0 for client_id in range(NUM_CLIENTS): client_total = len(train_indices[client_id]) total_row += f" {client_total:4d} |" grand_total += client_total total_row += f" {grand_total:4d} |" logger.info(total_row) logger.info("|" + "-" * 8 + "|" + "-" * 11 * NUM_CLIENTS + "|" + "-" * 7 + "|") # Print validation data distribution logger.info("VALIDATION DATA:") for class_id in range(NUM_CLASSES): class_row = f"| Cls {class_id:2d} |" class_total = 0 for client_id in range(NUM_CLIENTS): val_dist = get_class_distribution(full_val_dataset, val_indices[client_id]) count = val_dist.get(class_id, 0) class_row += f" {count:4d} |" class_total += count class_row += f" {class_total:4d} |" logger.info(class_row) # Print total validation samples per client total_row = "| Total |" grand_total = 0 for client_id in range(NUM_CLIENTS): client_total = len(val_indices[client_id]) total_row += f" {client_total:4d} |" grand_total += client_total total_row += f" {grand_total:4d} |" logger.info(total_row) logger.info("=" * 120)
print_enhanced_client_distribution_table(train_client_indices, val_client_indices, full_train_dataset, full_val_dataset)
plot_client_data_distribution(train_client_indices, val_client_indices, full_train_dataset, full_val_dataset)
def train_client(model, train_loader, val_loader, client_id, round_num):
"""Train a single client using AMP with iteration-based training and usage tracking"""
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scaler = GradScaler()
textbest_val_acc = 0 best_model_state = None model.to(DEVICE) for epoch in range(LOCAL_EPOCHS): model.train() total_loss, correct, total, total_perplexity = 0, 0, 0, 0 # Now all clients train for the same number of iterations for batch_idx, (images, labels) in enumerate(train_loader): images, labels = images.to(DEVICE), labels.squeeze().long().to(DEVICE) optimizer.zero_grad() with autocast(): outputs, vq_loss, decoded, original, perplexity = model(images) recon_loss = F.mse_loss(decoded, original.detach()) loss = criterion(outputs, labels) + 0.05 * vq_loss + 0.1 * recon_loss scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() total_loss += loss.item() * images.size(0) total_perplexity += perplexity.item() * images.size(0) _, preds = torch.max(outputs, 1) correct += (preds == labels).sum().item() total += labels.size(0) train_acc = 100 * correct / total train_loss = total_loss / total avg_perplexity = total_perplexity / total # Validation model.eval() val_correct, val_total, val_loss = 0, 0, 0 with torch.no_grad(): for images, labels in val_loader: images, labels = images.to(DEVICE), labels.squeeze().long().to(DEVICE) with autocast(): outputs, vq_loss, decoded, original, _ = model(images) recon_loss = F.mse_loss(decoded, original.detach()) loss = criterion(outputs, labels) + 0.05 * vq_loss + 0.1 * recon_loss val_loss += loss.item() * images.size(0) _, preds = torch.max(outputs, 1) val_correct += (preds == labels).sum().item() val_total += labels.size(0) val_acc = 100 * val_correct / val_total val_loss = val_loss / val_total logger.info(f"Client {client_id+1} | Round {round_num+1} | Epoch {epoch+1} | " f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | " f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}% | " f"Perplexity: {avg_perplexity:.2f}") if val_acc > best_val_acc: best_val_acc = val_acc best_model_state = deepcopy(model.state_dict()) # Load best model state if best_model_state is not None: model.load_state_dict(best_model_state) # Use smart checkpoint management save_best_checkpoint(client_id, round_num, best_model_state, best_val_acc) # Track codebook usage after training model.eval() used_indices, usage_counts = track_codebook_usage(model, train_loader, threshold=0.01) return model, best_val_acc, (used_indices, usage_counts)
def evaluate_global_test(model, global_test_loader):
"""Evaluate a model on global test dataset"""
model.eval()
correct, total = 0, 0
with torch.no_grad():
for images, labels in global_test_loader:
images, labels = images.to(DEVICE), labels.squeeze().long().to(DEVICE)
with autocast():
outputs, _, _, _, _ = model(images)
_, preds = torch.max(outputs, 1)
correct += (preds == labels).sum().item()
total += labels.size(0)
return 100 * correct / total
logger.info("\n" + "="*80)
logger.info("STARTING FEDERATED LEARNING")
logger.info("="*80)
client_models = [VQDecoderResNetClassifier().to(DEVICE) for _ in range(NUM_CLIENTS)]
best_client_accs = [0.0] * NUM_CLIENTS # Best global test accuracy for each client across all rounds
best_round_for_client = [0] * NUM_CLIENTS # Round where best accuracy was achieved
all_round_data = [] # Store all round data for consolidated tables
How does my codebook aggregation happen?
Your code aggregates the VQ “codebook” (i.e. the learned embedding table) by treating it as a clustering problem over all clients’ embeddings, rather than by simply averaging. Here’s how it works step by step:
"embeddings"
is treated as part of the codebook.torch.mean
.embeddings
matrix (shape [n_embeddings, embedding_dim]
), convert to NumPy, and stack them into one big array of shape [num_clients × n_embeddings, embedding_dim]
.KMeans(n_clusters=n_embeddings)
on that concatenated array.cluster_centers_
array (shape [n_embeddings, embedding_dim]
) becomes your aggregated codebook.torch.Tensor
and assign it to the global model’s codebook parameter.n_embeddings
, you either truncate or pad with random/noise-based vectors to maintain the correct size.Weighted K-Means (weighted_kmeans_clustering_vq
):
Usage-Aware K-Means (enhanced_kmeans_clustering_vq
):
Rather than averaging every codebook vector across clients, you treat all clients’ embeddings as a point cloud and re-derive the codebook via clustering, using the resulting cluster centers as the new shared codebook. All non-codebook parameters remain federated-averaged in the usual way. This lets your global codebook better capture the diversity of features learned by each client.
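To make the unweighted path concrete, here is a minimal, self-contained sketch of the same idea with toy dimensions (3 clients, an 8×4 codebook). It mirrors what kmeans_clustering_vq does to the embeddings parameter, leaving out the size guard and logging; it is an illustration, not the full function above.

import numpy as np
import torch
from sklearn.cluster import KMeans

n_clients, n_embeddings, embedding_dim = 3, 8, 4
client_codebooks = [np.random.randn(n_embeddings, embedding_dim) for _ in range(n_clients)]

# 1. Stack all clients' codebook vectors into one point cloud
all_embeddings = np.concatenate(client_codebooks, axis=0)  # (n_clients * n_embeddings, embedding_dim)

# 2. Cluster the point cloud back down to the codebook size
kmeans = KMeans(n_clusters=n_embeddings, n_init=10, random_state=42).fit(all_embeddings)

# 3. The cluster centers become the new shared codebook
global_codebook = torch.tensor(kmeans.cluster_centers_, dtype=torch.float32)
print(global_codebook.shape)  # torch.Size([8, 4])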
what is vq embeddings?
In the context of your code (and more broadly in models like VQ-VAE or VQ-GAN), “VQ embeddings” refers to the discrete set of prototype vectors—often called the codebook—that your continuous latent representations get quantized against. Here’s what that means:
- Continuous → Discrete: the ResNet encoder produces continuous feature maps; the VQ layer maps each spatial feature vector onto an entry of a fixed table of prototype vectors, turning the continuous representation into a (soft) discrete code.
- The Codebook: in your model this table is self.embeddings = nn.Parameter(torch.randn(num_embeddings, embedding_dim)) inside GumbelSoftmaxVQ, i.e. 512 learnable vectors of dimension 512.
- Why Quantize?: the discrete bottleneck compresses and regularizes the representation, and the perplexity you log measures how many of the 512 entries are actually being used.
- Learning the Embeddings: your GumbelSoftmaxVQ layer makes this discrete selection differentiable, so you can backpropagate through the "argmin" step and train the codebook jointly with the encoder and classifier.
- Aggregation in Federated Setting: these codebook vectors are exactly the "embeddings" parameters that the K-means aggregation above clusters across clients.

In short: VQ embeddings are the learnable codebook vectors that your encoder features get quantized against.
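As a small standalone illustration (not part of the training script), the sketch below shows plain hard vector quantization against a codebook plus the perplexity computation; the model above replaces the hard argmin with a Gumbel-softmax relaxation so the selection stays differentiable.

import torch
import torch.nn.functional as F

codebook = torch.randn(512, 512)   # num_embeddings x embedding_dim, like GumbelSoftmaxVQ.embeddings
features = torch.randn(64, 512)    # 64 continuous feature vectors from the encoder

# Hard VQ: snap each feature to its nearest codebook vector
distances = torch.cdist(features, codebook)   # (64, 512) pairwise distances
codes = distances.argmin(dim=1)               # discrete code indices
quantized = codebook[codes]                   # (64, 512) quantized features

# Perplexity: how many codebook entries are effectively in use
one_hot = F.one_hot(codes, num_classes=codebook.size(0)).float()
avg_probs = one_hot.mean(dim=0)
perplexity = torch.exp(-(avg_probs * torch.log(avg_probs + 1e-10)).sum())
print(codes.shape, quantized.shape, perplexity.item())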