I need to advance the architecture of my current multi-label multiclass classification training pipeline to SOTA…
Created: May 3, 2025
I need to advance the architecture of my current multi-label multiclass classification training pipeline to SOTA level:
python# train_hpo.py import argparse import os import random import warnings import logging import sys import pickle import yaml # For loading YAML config import functools # For passing args to objective from tqdm import tqdm import numpy as np import pandas as pd import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F from torch.amp import GradScaler, autocast from torch.utils.data import Dataset, DataLoader, default_collate from sklearn.model_selection import train_test_split from sklearn.metrics import f1_score import optuna # --- TimeMixerPP Imports and Patches --- # Assume TimeMixerPP directory is in PYTHONPATH or current dir try: sys.path.append('./TimeMixerPP') # Ensure TimeMixerPP is findable from models.TimeMixerPP import Model as TimeMixerPP_Model, MixerBlock from layers.Embed import DataEmbedding_wo_pos from utils.tools import dotdict # Keep dotdict if used internally by TimeMixerPP from utils.metrics import metric import models.TimeMixerPP as timemixerpp_module except ImportError as e: print(f"Error importing TimeMixerPP modules: {e}") print("Please ensure the TimeMixerPP directory is in your Python path or run from the TimeMixerPP root.") sys.exit(1) # --- Logging Setup --- log_level = logging.INFO # Default, can be overridden by config later if needed logging.basicConfig(level=log_level, format='%(asctime)s - %(levelname)s - [%(process)d] %(message)s', stream=sys.stderr) warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=RuntimeWarning) # --- FFT Patch --- # --- (Patched FFT code - unchanged from previous version) --- try: original_FFT_for_Period = timemixerpp_module.FFT_for_Period except AttributeError: logging.warning("Could not find original FFT_for_Period function."); original_FFT_for_Period = None def patched_FFT_for_Period(x, k=2): if not isinstance(x, torch.Tensor): raise TypeError(f"Input x must be a torch.Tensor, got {type(x)}") original_dtype = x.dtype if x.dtype != torch.float32: x_float32 = x.to(torch.float32) else: x_float32 = x try: xf = torch.fft.rfft(x_float32, dim=1) except Exception as e: logging.error(f"Error during torch.fft.rfft: {e}"); logging.error(f"Input shape: {x_float32.shape}, dtype: {x_float32.dtype}, device: {x_float32.device}"); raise e if not torch.is_complex(xf): logging.warning(f"torch.fft.rfft output is not complex (dtype: {xf.dtype}). This might indicate an issue."); xf_abs = abs(xf) else: xf_abs = abs(xf) if xf_abs.numel() == 0: return np.array([], dtype=np.int64), torch.tensor([]).to(x.device).reshape(x.shape[0], 0), np.array([]) frequency_list = xf_abs.mean(0); if frequency_list.dim() > 1: frequency_list = frequency_list.mean(-1) frequency_list[0] = 0 actual_k = min(k, len(frequency_list)) if actual_k <= 0: return np.array([], dtype=np.int64), torch.tensor([]).to(x.device).reshape(x.shape[0], 0), np.array([]) try: _, top_list_indices = torch.topk(frequency_list, actual_k) except RuntimeError as e: logging.error(f"Error in torch.topk: {e}. 
frequency_list size: {len(frequency_list)}, k: {actual_k}") if actual_k > len(frequency_list): actual_k = len(frequency_list) if actual_k == 0: return np.array([], dtype=np.int64), torch.tensor([]).to(x.device).reshape(x.shape[0], 0), np.array([]) _, top_list_indices = torch.topk(frequency_list, actual_k) else: raise e top_list_indices_cpu = top_list_indices.detach().cpu().numpy() periods = np.zeros_like(top_list_indices_cpu, dtype=float) valid_indices_mask = top_list_indices_cpu != 0 safe_indices = top_list_indices_cpu[valid_indices_mask] if np.any(safe_indices == 0): logging.warning("Found zero index in FFT top list after initial filtering. Removing.") zero_mask = safe_indices != 0 safe_indices = safe_indices[zero_mask] valid_indices_mask = np.isin(top_list_indices_cpu, safe_indices) if np.any(valid_indices_mask): periods[valid_indices_mask] = x.shape[1] / top_list_indices_cpu[valid_indices_mask] valid_period_mask = periods > 0 top_list_filtered = top_list_indices_cpu[valid_period_mask]; periods_filtered = periods[valid_period_mask] periods_filtered_int = np.round(periods_filtered).astype(np.int64) non_zero_mask_after_round = periods_filtered_int > 0 periods_filtered_int = periods_filtered_int[non_zero_mask_after_round]; top_list_filtered = top_list_filtered[non_zero_mask_after_round] top_list_indices_tensor = torch.from_numpy(top_list_filtered).long().to(frequency_list.device) if top_list_indices_tensor.numel() > 0: weights_all_freqs = xf_abs if weights_all_freqs.dim() == 1: weights_all_freqs = weights_all_freqs.unsqueeze(0) if weights_all_freqs.dim() == 3: weights_all_freqs = weights_all_freqs.mean(-1) elif weights_all_freqs.dim() != 2: raise ValueError(f"Unexpected shape for weights_all_freqs: {weights_all_freqs.shape}") max_freq_index = weights_all_freqs.shape[1] - 1 valid_fft_indices_mask = top_list_indices_tensor <= max_freq_index valid_fft_indices = top_list_indices_tensor[valid_fft_indices_mask] if len(valid_fft_indices) != len(top_list_indices_tensor): logging.warning(f"Some top FFT indices out of bounds. Filtering.") periods_filtered_int = periods_filtered_int[valid_fft_indices_mask.cpu().numpy()] top_list_filtered = top_list_filtered[valid_fft_indices_mask.cpu().numpy()] top_list_indices_tensor = valid_fft_indices if top_list_indices_tensor.numel() > 0: period_weights = weights_all_freqs[:, top_list_indices_tensor] else: period_weights = torch.tensor([]).to(x.device).reshape(xf_abs.shape[0], 0) else: period_weights = torch.tensor([]).to(x.device).reshape(xf_abs.shape[0], 0) if not isinstance(periods_filtered_int, np.ndarray): raise TypeError("periods_filtered_int is not np.ndarray") if not isinstance(period_weights, torch.Tensor): raise TypeError("period_weights is not torch.Tensor") if not isinstance(top_list_filtered, np.ndarray): raise TypeError("top_list_filtered is not np.ndarray") return periods_filtered_int, period_weights, top_list_filtered if hasattr(timemixerpp_module, 'FFT_for_Period') and original_FFT_for_Period is not None: timemixerpp_module.FFT_for_Period = patched_FFT_for_Period logging.debug("Successfully patched 'models.TimeMixerPP.FFT_for_Period'.") elif not hasattr(timemixerpp_module, 'FFT_for_Period'): logging.warning("Could not find 'models.TimeMixerPP.FFT_for_Period'. Patching skipped.") else: logging.warning("Original FFT_for_Period not found during setup. 
Patching skipped.") # --- multi_reso_mixing Patch --- # --- (Patched multi_reso_mixing code - unchanged from previous version) --- try: original_multi_reso_mixing = MixerBlock.multi_reso_mixing except AttributeError: logging.warning("Could not find original MixerBlock.multi_reso_mixing method."); original_multi_reso_mixing = None def patched_multi_reso_mixing(self, period_weight, x, res): B_actual, T_actual, N_actual = x.size() if period_weight.shape[1] > 0: period_weight_softmax = F.softmax(period_weight, dim=1) period_weight_reshaped = period_weight_softmax.unsqueeze(1).unsqueeze(1) period_weight_repeated = period_weight_reshaped.repeat(1, T_actual, N_actual, 1) res_aggregated = torch.sum(res * period_weight_repeated, -1) else: res_aggregated = torch.zeros_like(x) res_final = res_aggregated + x return res_final if hasattr(MixerBlock, 'multi_reso_mixing') and original_multi_reso_mixing is not None: MixerBlock.multi_reso_mixing = patched_multi_reso_mixing logging.debug("Successfully patched 'models.TimeMixerPP.MixerBlock.multi_reso_mixing'.") elif not hasattr(MixerBlock, 'multi_reso_mixing'): logging.warning("Could not find 'models.TimeMixerPP.MixerBlock.multi_reso_mixing'. Patching skipped.") else: logging.warning("Original multi_reso_mixing not found during setup. Patching skipped.") # --- Utility Functions --- def control_randomness(seed: int = 42): random.seed(seed); os.environ["PYTHONHASHSEED"] = str(seed); np.random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed(seed) if torch.cuda.is_available(): torch.backends.cudnn.deterministic = True; torch.backends.cudnn.benchmark = False; torch.cuda.manual_seed_all(seed) # --- Data Loading & Preprocessing (Takes config dict) --- def load_or_preprocess_data(config: dict): data_cfg = config['data_settings'] fixed_cfg = config['fixed_params'] cache_dir = data_cfg.get('cache_dir', './preprocess_cache') # Use get for optional cache cache_data_path = os.path.join(cache_dir, 'processed_data.pkl') cache_labels_path = os.path.join(cache_dir, 'processed_labels.pkl') cache_n_channels_path = os.path.join(cache_dir, 'n_channels.txt') os.makedirs(cache_dir, exist_ok=True) # --- (Data loading/preprocessing code - uses paths/params from config) --- if (data_cfg.get('use_cache', True) and # Check if caching enabled os.path.exists(cache_data_path) and os.path.exists(cache_labels_path) and os.path.exists(cache_n_channels_path)): logging.info(f"Loading preprocessed data from cache: {cache_dir}") try: with open(cache_data_path, 'rb') as f: all_data = pickle.load(f) with open(cache_labels_path, 'rb') as f: all_labels = pickle.load(f) with open(cache_n_channels_path, 'r') as f: actual_n_channels = int(f.read()) logging.info(f"Loaded {len(all_data)} samples from cache. Actual N_CHANNELS: {actual_n_channels}") # --- Add Validation: Check if cached seq_len matches config --- if len(all_data) > 0 and all_data[0].shape[0] != data_cfg['seq_len']: logging.warning(f"Cached data seq_len ({all_data[0].shape[0]}) != config seq_len ({data_cfg['seq_len']}). Reprocessing...") raise ValueError("Sequence length mismatch in cache.") # --- Add Validation: Check if cached channels matches config (if provided) --- # This is tricky as n_channels is usually *determined* by data, not set in config. # We'll rely on the preprocess function to determine it and save it. return all_data, all_labels, actual_n_channels except Exception as e: logging.warning(f"Failed to load data from cache: {e}. 
Reprocessing...") if os.path.exists(cache_data_path): os.remove(cache_data_path) if os.path.exists(cache_labels_path): os.remove(cache_labels_path) if os.path.exists(cache_n_channels_path): os.remove(cache_n_channels_path) logging.info("Cache not found, loading failed, or cache disabled. Starting data preprocessing...") all_data, all_labels, actual_n_channels = preprocess_data_timeseries(config) # Pass full config if data_cfg.get('use_cache', True): logging.info(f"Caching preprocessed data to {cache_dir}...") try: with open(cache_data_path, 'wb') as f: pickle.dump(all_data, f) with open(cache_labels_path, 'wb') as f: pickle.dump(all_labels, f) with open(cache_n_channels_path, 'w') as f: f.write(str(actual_n_channels)) logging.info("Successfully cached preprocessed data.") except Exception as e: logging.error(f"Failed to cache preprocessed data: {e}") return all_data, all_labels, actual_n_channels def preprocess_data_timeseries(config: dict): data_cfg = config['data_settings'] fixed_cfg = config['fixed_params'] target_seq_len = data_cfg['seq_len'] label_names = fixed_cfg['label_names'] time_series_path = data_cfg['train_path'] label_path = data_cfg['label_path'] # n_channels_config is not used here, it's determined from data logging.info(f"Loading raw data from {time_series_path} and {label_path}...") try: ts_df = pd.read_parquet(time_series_path) label_df = pd.read_csv(label_path) except FileNotFoundError as e: logging.error(f"Error loading raw data files: {e}"); raise # --- (Rest of preprocessing logic - unchanged, uses local vars derived from config) --- if 'timestamp' in ts_df.columns: ts_df['timestamp'] = pd.to_datetime(ts_df['timestamp']) ts_df['lifelog_date'] = ts_df['timestamp'].dt.date else: if 'lifelog_date' not in ts_df.columns: raise ValueError("Time series data must contain 'timestamp' or 'lifelog_date'.") ts_df['lifelog_date'] = pd.to_datetime(ts_df['lifelog_date']).dt.date if 'lifelog_date' in label_df.columns: label_df['lifelog_date'] = pd.to_datetime(label_df['lifelog_date']).dt.date else: raise ValueError("Label data must contain 'lifelog_date'.") potential_keys = ['timestamp', 'subject_id', 'lifelog_date'] sensor_cols = [col for col in ts_df.columns if col not in potential_keys] if not sensor_cols: raise ValueError("Could not automatically identify sensor columns.") actual_n_channels = len(sensor_cols) n_channels = actual_n_channels # Use actual channel count logging.info(f"Identified {n_channels} sensor columns: {sensor_cols[:5]}...") logging.info("Merging time series data with labels...") merge_keys = [] if 'subject_id' in ts_df.columns and 'subject_id' in label_df.columns: merge_keys.append('subject_id') if 'lifelog_date' in ts_df.columns and 'lifelog_date' in label_df.columns: merge_keys.append('lifelog_date') if not merge_keys: raise ValueError("Cannot merge dataframes. Missing common keys ('subject_id' and/or 'lifelog_date').") logging.info(f"Merging on keys: {merge_keys}") required_label_cols = merge_keys + label_names missing_label_cols = [col for col in required_label_cols if col not in label_df.columns] if missing_label_cols: raise ValueError(f"Label dataframe is missing required columns: {missing_label_cols}") merged_df = pd.merge(ts_df, label_df[required_label_cols], on=merge_keys, how='inner') logging.info(f"Merged dataframe shape: {merged_df.shape}") if merged_df.empty: raise ValueError("Merged dataframe is empty. 
Check merge keys and data content.") logging.info(f"Processing time series data per day for target length {target_seq_len}...") processed_data_list = [] processed_labels = [] grouping_keys = [key for key in merge_keys if key in merged_df.columns] if not grouping_keys: raise ValueError("Grouping keys not found after merge.") grouped = merged_df.groupby(grouping_keys) logging.info(f"Grouping data by: {grouping_keys}. Number of groups: {len(grouped)}") for name, group in tqdm(grouped, desc="Processing Days"): try: labels = group.iloc[0][label_names].values.astype(np.int64).tolist() except Exception as e: logging.error(f"Error converting labels for group {name}: {e}. Labels: {group.iloc[0][label_names].values}"); continue day_data = group[sensor_cols].fillna(0).values.astype(np.float32) current_len = day_data.shape[0] if current_len == 0: logging.warning(f"Skipping group {name} due to empty sensor data."); continue if current_len == target_seq_len: processed_day_data = day_data elif current_len < target_seq_len: padding_len = target_seq_len - current_len padding = np.zeros((padding_len, n_channels), dtype=np.float32) processed_day_data = np.concatenate([day_data, padding], axis=0) else: processed_day_data = day_data[-target_seq_len:, :] if processed_day_data.shape != (target_seq_len, n_channels): logging.error(f"Shape mismatch for group {name} after processing: expected {(target_seq_len, n_channels)}, got {processed_day_data.shape}. Skipping."); continue processed_data_list.append(processed_day_data) processed_labels.append(labels) logging.info(f"Finished processing. Found {len(processed_data_list)} valid samples.") if not processed_data_list: raise ValueError("No data processed. Check input files and processing logic.") labels_array = np.array(processed_labels, dtype=np.int64) return processed_data_list, labels_array, n_channels # --- Dataset Class --- class TimeSeriesDataset(Dataset): # --- (Dataset code - unchanged) --- def __init__(self, data_list, labels): self.data_list = data_list self.labels = labels def __len__(self): return len(self.labels) def __getitem__(self, idx): series_data_np = self.data_list[idx] if series_data_np.dtype != np.float32: series_data_np = series_data_np.astype(np.float32) series_data = torch.from_numpy(series_data_np) label_vector_np = self.labels[idx] if label_vector_np.dtype != np.int64: label_vector_np = label_vector_np.astype(np.int64) label_vector = torch.from_numpy(label_vector_np) if not isinstance(series_data, torch.Tensor) or not isinstance(label_vector, torch.Tensor): raise TypeError(f"Dataset __getitem__ returned non-Tensor types for index {idx}. 
Got {type(series_data)} and {type(label_vector)}") return series_data, label_vector # --- Model Definitions --- class IndependentMLPHeads(nn.Module): # --- (MLP Head code - unchanged) --- def __init__(self, input_dim, num_classes_per_label, label_names, hidden_dim=512, dropout=0.1): super().__init__() self.heads = nn.ModuleDict() for i, name in enumerate(label_names): num_classes = num_classes_per_label[i] self.heads[name] = nn.Sequential( nn.Linear(input_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden_dim, num_classes) ) def forward(self, x): outputs = {} for name, head in self.heads.items(): outputs[name] = head(x) return outputs class TimeMixerPPClassifier(nn.Module): # --- (Classifier code - takes config dict for init) --- def __init__(self, model_config: dict, num_classes_per_label: list, label_names: list): super().__init__() self.configs = argparse.Namespace(**model_config) # Use Namespace for attribute access self.label_names = label_names self.timemixer_encoder = TimeMixerPP_Model(self.configs) if self.configs.channel_independence == 1: encoder_output_dim = self.configs.d_model else: encoder_output_dim = self.configs.d_model; logging.warning("Channel dependence used. Ensure aggregation handles multiple channels.") self.classifier_head = IndependentMLPHeads( input_dim=encoder_output_dim, num_classes_per_label=num_classes_per_label, label_names=label_names, hidden_dim=self.configs.mlp_hidden_dim, # Get from model_config dropout=self.configs.mlp_dropout # Get from model_config ) def forward(self, x_enc, x_mark_enc=None): # --- (Forward pass logic - unchanged) --- try: x_enc_list, x_mark_enc_list = self.timemixer_encoder._Model__multi_scale_process_inputs(x_enc, x_mark_enc) except AttributeError: logging.error("Could not find '_Model__multi_scale_process_inputs'."); raise except Exception as e: logging.error(f"Error during multi_scale_process_inputs: {e}"); raise x_processed_list = []; x_mark_processed_list = [] revin_layers = self.timemixer_encoder.revin_layers marks_to_process = x_mark_enc_list if x_mark_enc_list is None: marks_to_process = [None] * len(x_enc_list) elif len(x_enc_list) != len(x_mark_enc_list): raise ValueError(f"Mismatch lengths x_enc_list/x_mark_enc_list.") for i, (x, x_mark) in enumerate(zip(x_enc_list, marks_to_process)): if hasattr(self.configs, 'revin') and self.configs.revin == 1: if i < len(revin_layers): x = revin_layers[i](x, 'norm') else: logging.warning(f"RevIN layer index {i} out of bounds. Skipping.") if self.configs.channel_independence == 1: B, T_i, N_i = x.shape x = x.permute(0, 2, 1).contiguous().reshape(B * N_i, T_i, 1) if x_mark is not None: x_mark = x_mark.repeat_interleave(N_i, dim=0) x_processed_list.append(x); x_mark_processed_list.append(x_mark) if self.configs.channel_mixing and self.configs.channel_independence == 1: _, T_coarse, D_model = x_processed_list[-1].size() B_orig, N_orig = x_enc.shape[0], self.configs.enc_in coarse_scale_enc_out_flat = x_processed_list[-1].reshape(B_orig, N_orig, T_coarse * D_model) if hasattr(self.timemixer_encoder, 'channel_mixing_attention'): coarse_scale_enc_out_mixed, _ = self.timemixer_encoder.channel_mixing_attention(coarse_scale_enc_out_flat, coarse_scale_enc_out_flat, coarse_scale_enc_out_flat, None) x_processed_list[-1] = coarse_scale_enc_out_mixed.reshape(B_orig * N_orig, T_coarse, D_model) + x_processed_list[-1] else: logging.warning("Channel mixing enabled, but layer not found. 
Skipping.") enc_out_list = [] embedding_layer = self.timemixer_encoder.enc_embedding if not hasattr(embedding_layer, '__call__'): raise AttributeError("No callable 'enc_embedding'.") for x, x_mark in zip(x_processed_list, x_mark_processed_list): enc_out = embedding_layer(x, x_mark); enc_out_list.append(enc_out) processed_enc_out_list = enc_out_list encoder_layers = self.timemixer_encoder.encoder_model if not isinstance(encoder_layers, nn.ModuleList): raise AttributeError("No 'encoder_model' ModuleList.") for i in range(self.configs.e_layers): if i >= len(encoder_layers): logging.warning(f"Configured e_layers > actual layers. Stopping."); break layer = encoder_layers[i] if not callable(layer): raise TypeError(f"Encoder layer {i} not callable.") processed_enc_out_list = layer(processed_enc_out_list) if not isinstance(processed_enc_out_list, list): raise TypeError(f"MixerBlock {i} didn't return list.") if not processed_enc_out_list: raise ValueError(f"MixerBlock {i} returned empty list.") if not isinstance(processed_enc_out_list[0], torch.Tensor): raise TypeError(f"MixerBlock {i} list has non-Tensor.") if not processed_enc_out_list: raise ValueError("Encoder output list empty.") final_enc_out = processed_enc_out_list[0] if not isinstance(final_enc_out, torch.Tensor): raise TypeError(f"Final enc_out not Tensor.") aggregated_features = final_enc_out.mean(dim=1) if self.configs.channel_independence == 1: B_actual, D = aggregated_features.shape N_ch = self.configs.enc_in if N_ch == 0: raise ValueError("N_ch cannot be zero.") # --- Add check for division by zero if N_ch might be 0 --- if N_ch > 0 and B_actual % N_ch != 0: logging.error(f"Shape mismatch B_actual={B_actual}, N_ch={N_ch}. Check drop_last=True.") raise ValueError("Shape mismatch.") B_inferred = B_actual // N_ch if N_ch > 0 else B_actual # Avoid division by zero # Reshape only if N_ch > 0 and reshaping is needed if N_ch > 0 and B_actual != B_inferred: aggregated_features = aggregated_features.view(B_inferred, N_ch, D).mean(dim=1) # else: # If N_ch=0 or B_actual == B_inferred, use as is (or handle error) # pass # Or raise error if N_ch == 0 is unexpected outputs = self.classifier_head(aggregated_features) return outputs # --- Evaluation Metric --- def calculate_avg_macro_f1(all_true_labels, all_pred_logits, label_names): # --- (F1 calculation code - unchanged) --- f1_scores = {} num_samples = all_true_labels.shape[0] for i, name in enumerate(label_names): true_task = all_true_labels[:, i] pred_task_logits = all_pred_logits[name].cpu() pred_task_labels = torch.argmax(pred_task_logits, dim=1).numpy() true_task_np = true_task.cpu().numpy() if len(np.unique(true_task_np)) < 2 and len(np.unique(pred_task_labels)) < 2: f1 = 1.0 if np.array_equal(true_task_np, pred_task_labels) else 0.0 else: f1 = f1_score(true_task_np, pred_task_labels, average='macro', zero_division=0) f1_scores[name] = f1 avg_f1 = np.mean(list(f1_scores.values())) return avg_f1, f1_scores # --- Training and Evaluation Functions (Take config) --- criterion = nn.CrossEntropyLoss() def train_epoch_timemixer(model, device, train_dataloader, criterion, optimizer, scheduler, scaler, config: dict, current_epoch: int): # --- (Train epoch code - largely unchanged, uses config for params) --- model.train() total_loss = 0.0 optimizer.zero_grad() amp_enabled = scaler.is_enabled() label_names = config['fixed_params']['label_names'] accumulation_steps = config['hpo_params']['accumulation_steps'] # Get from HPO params now gradient_clip_val = config['hpo_params'].get('gradient_clip_val') # 
Optional clip val # No tqdm for HPO runs by default to reduce log spam # pbar = tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f"Epoch {current_epoch} Train", leave=False) for i, (batch_x, batch_labels) in enumerate(train_dataloader): batch_x = batch_x.to(device) batch_labels = batch_labels.to(device) batch_x_mark = None with autocast(device_type=device.type, dtype=torch.float16, enabled=amp_enabled): outputs = model(batch_x, batch_x_mark) batch_loss = 0.0 for task_idx, name in enumerate(label_names): loss = criterion(outputs[name], batch_labels[:, task_idx]) batch_loss += loss batch_loss /= len(label_names) loss_accum = batch_loss / accumulation_steps scaler.scale(loss_accum).backward() if (i + 1) % accumulation_steps == 0: if gradient_clip_val is not None and gradient_clip_val > 0: # Check > 0 scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip_val) scaler.step(optimizer) scaler.update() optimizer.zero_grad() if scheduler is not None and isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR): scheduler.step() total_loss += batch_loss.item() # pbar.set_postfix({'loss': batch_loss.item()}) avg_loss = total_loss / len(train_dataloader) if len(train_dataloader) > 0 else 0 if scheduler is not None and not isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR): scheduler.step() return avg_loss def evaluate_epoch_timemixer(model, device, dataloader, criterion, config: dict, phase='val', current_epoch=0): # --- (Evaluate epoch code - largely unchanged, uses config) --- model.eval() all_true_labels = [] all_pred_logits = {name: [] for name in config['fixed_params']['label_names']} total_loss = 0.0 amp_enabled = (device.type == 'cuda') label_names = config['fixed_params']['label_names'] with torch.no_grad(): # pbar = tqdm(dataloader, desc=f"Epoch {current_epoch} Eval {phase}", leave=False) for batch_x, batch_labels in dataloader: batch_x = batch_x.to(device) batch_labels = batch_labels.to(device) batch_x_mark = None with autocast(device_type=device.type, dtype=torch.float16, enabled=amp_enabled): outputs = model(batch_x, batch_x_mark) batch_loss = 0.0 for i, name in enumerate(label_names): task_logits = outputs[name] loss = criterion(task_logits, batch_labels[:, i]) batch_loss += loss all_pred_logits[name].append(task_logits.cpu().float()) batch_loss /= len(label_names) total_loss += batch_loss.item() all_true_labels.append(batch_labels.cpu()) # pbar.set_postfix({'loss': batch_loss.item()}) avg_loss = total_loss / len(dataloader) if len(dataloader) > 0 else 0 all_true_labels_cat = torch.cat(all_true_labels, dim=0) if all_true_labels else torch.empty(0, len(label_names)) all_pred_logits_cat = {name: torch.cat(all_pred_logits[name], dim=0) for name in label_names} if all_pred_logits[label_names[0]] else {name: torch.empty(0) for name in label_names} # Handle case with no evaluation data if all_true_labels_cat.numel() == 0: logging.warning("Evaluation dataloader was empty. 
Returning F1=0.") return avg_loss, 0.0, {name: 0.0 for name in label_names} avg_f1, f1_scores_dict = calculate_avg_macro_f1(all_true_labels_cat, all_pred_logits_cat, label_names) return avg_loss, avg_f1, f1_scores_dict # --- Optuna Objective Function (Takes config) --- def suggest_hyperparameters(trial: optuna.trial.Trial, search_space_config: dict): """Suggests hyperparameters based on the search space config.""" params = {} for name, cfg in search_space_config.items(): suggest_type = cfg['type'] if suggest_type == 'categorical': params[name] = trial.suggest_categorical(name, cfg['choices']) elif suggest_type == 'int': params[name] = trial.suggest_int(name, cfg['low'], cfg['high'], step=cfg.get('step', 1), log=cfg.get('log', False)) elif suggest_type == 'float': params[name] = trial.suggest_float(name, cfg['low'], cfg['high'], step=cfg.get('step'), log=cfg.get('log', False)) else: raise ValueError(f"Unsupported suggestion type: {suggest_type} for parameter {name}") return params def objective(trial: optuna.trial.Trial, config: dict, preloaded_data: tuple, device: torch.device) -> float: """Optuna objective function driven by YAML config.""" # --- 1. Suggest Hyperparameters based on Config --- try: hpo_params = suggest_hyperparameters(trial, config['search_space']) # Combine suggested HPO params with fixed HPO params from config hpo_params.update(config.get('hpo_params', {})) # Adds batch_size, accumulation_steps etc. if defined there # Store combined params for the trial (useful for analysis) trial.set_user_attr("hpo_params", hpo_params) except Exception as e: logging.error(f"Error suggesting hyperparameters: {e}", exc_info=True) raise # Re-raise to mark trial as failed # --- 2. Setup Trial --- control_randomness(config['fixed_params']['seed'] + trial.number) # Seed per trial for reproducibility *if needed* all_data, all_labels, n_channels = preloaded_data if not all_data or all_labels is None or n_channels is None: logging.error("Preloaded data is missing or invalid!") return -1.0 # Indicate failure # --- 3. Data Splitting and Loaders --- indices = np.arange(len(all_data)) stratify_col = all_labels[:, 0] if all_labels.shape[1] > 0 and len(np.unique(all_labels[:, 0])) >= 2 else None train_idx, val_idx = train_test_split( indices, test_size=config['data_settings']['val_split_ratio'], random_state=config['fixed_params']['seed'], # Fixed seed for split consistency across trials stratify=stratify_col ) train_dataset = TimeSeriesDataset([all_data[i] for i in train_idx], all_labels[train_idx]) val_dataset = TimeSeriesDataset([all_data[i] for i in val_idx], all_labels[val_idx]) # Use hpo_params for batch size and accumulation steps batch_size = hpo_params['batch_size'] accumulation_steps = hpo_params['accumulation_steps'] num_workers = config['hpo_settings'].get('num_workers', min(os.cpu_count() // 2 if os.cpu_count() else 0, 4)) train_loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=(num_workers > 0 and device.type == 'cuda'), drop_last=True ) val_loader = DataLoader( val_dataset, batch_size=batch_size * 2, shuffle=False, num_workers=num_workers, pin_memory=(num_workers > 0 and device.type == 'cuda'), drop_last=False ) # --- 4. 
Model Initialization --- # Combine fixed params and HPO params for model config model_config_dict = {**config['fixed_params'], **hpo_params} # HPO params override fixed ones if names clash model_config_dict['enc_in'] = n_channels # Use actual N channels model_config_dict['c_out'] = n_channels # Usually needed for TimeMixerPP init model_config_dict['seq_len'] = config['data_settings']['seq_len'] # Get from data settings model_config_dict['pred_len'] = 0 # Classification specific model_config_dict['label_len'] = 0 # Classification specific model_config_dict['task_name'] = 'classification' # Ensure d_ff is calculated if needed if 'd_ff_multiple' in hpo_params: model_config_dict['d_ff'] = model_config_dict['d_model'] * hpo_params['d_ff_multiple'] # Ensure MLP params are present in the main config dict # (They should already be there from the merge above if defined in search_space or fixed_params) if 'mlp_hidden_dim' not in model_config_dict: # Provide a default or raise error if it's mandatory logging.warning("mlp_hidden_dim not found in combined config, check YAML search_space/fixed_params.") # model_config_dict['mlp_hidden_dim'] = 256 # Example default raise ValueError("mlp_hidden_dim is required in the model configuration.") if 'mlp_dropout' not in model_config_dict: logging.warning("mlp_dropout not found in combined config, check YAML search_space/fixed_params.") # model_config_dict['mlp_dropout'] = 0.1 # Example default raise ValueError("mlp_dropout is required in the model configuration.") # Add potentially missing default TimeMixerPP args if not in config/hpo # These should be added to the main model_config_dict model_config_dict.setdefault('d_layers', 1) model_config_dict.setdefault('embed', 'timeF') model_config_dict.setdefault('freq', 'h') model_config_dict.setdefault('activation', 'gelu') model_config_dict.setdefault('output_attention', False) model_config_dict.setdefault('decomp_method', 'moving_avg') model_config_dict.setdefault('use_norm', 1) model_config_dict.setdefault('revin', 1) model_config_dict.setdefault('affine', 0) model_config_dict.setdefault('subtract_last', 0) model_config_dict.setdefault('decomposition', 0) model_config_dict.setdefault('kernel_size', 25) model_config_dict.setdefault('individual', 0) try: # Pass the combined dictionary as 'model_config' model = TimeMixerPPClassifier( model_config=model_config_dict, # <-- 수정: 올바른 인수 이름 사용 num_classes_per_label=config['fixed_params']['num_classes_per_label'], label_names=config['fixed_params']['label_names'] # <-- 수정: **classifier_params 제거 ) model.to(device) except Exception as e: logging.error(f"Error initializing model: {e}", exc_info=True) # 로그에 사용된 전체 설정 딕셔너리 출력 logging.error(f"Model Config used for init: {model_config_dict}") raise # Fail the trial # --- 5. 
Optimizer, Scheduler, Scaler --- optimizer = optim.AdamW(model.parameters(), lr=hpo_params['lr'], weight_decay=hpo_params['weight_decay']) epochs_per_trial = config['fixed_params']['epochs_per_trial'] num_optimizer_steps_per_epoch = len(train_loader) // accumulation_steps if len(train_loader) % accumulation_steps != 0: num_optimizer_steps_per_epoch +=1 if num_optimizer_steps_per_epoch == 0: num_optimizer_steps_per_epoch = 1 num_training_steps = epochs_per_trial * num_optimizer_steps_per_epoch # Optional: Make scheduler type/params configurable via YAML scheduler = optim.lr_scheduler.OneCycleLR( optimizer, max_lr=hpo_params['lr'], total_steps=num_training_steps, pct_start=hpo_params.get('pct_start', 0.1) # Get pct_start from HPO params or default ) scaler = GradScaler(device=device, enabled=(device.type == 'cuda')) # --- 6. Training Loop with Pruning --- best_val_f1 = -1.0 try: for epoch in range(epochs_per_trial): epoch_num = epoch + 1 train_loss = train_epoch_timemixer( model, device, train_loader, criterion, optimizer, scheduler, scaler, config, current_epoch=epoch_num # Pass full config ) val_loss, val_f1, val_f1_tasks = evaluate_epoch_timemixer( model, device, val_loader, criterion, config, phase='val', current_epoch=epoch_num # Pass config ) current_lr = optimizer.param_groups[0]['lr'] logging.debug(f'T{trial.number} E{epoch_num}/{epochs_per_trial} | LR: {current_lr:.2e} | TrL: {train_loss:.4f} | Vl L: {val_loss:.4f} | Vl F1: {val_f1:.4f}') if val_f1 > best_val_f1: best_val_f1 = val_f1 # Optional: Save best model checkpoint per trial if needed # save_path = f"hpo_checkpoints/trial_{trial.number}_best_f1_{best_val_f1:.4f}.pth" # os.makedirs(os.path.dirname(save_path), exist_ok=True) # torch.save(model.state_dict(), save_path) trial.report(val_f1, epoch) if trial.should_prune(): logging.debug(f"Trial {trial.number} pruned at epoch {epoch_num}.") raise optuna.exceptions.TrialPruned() except optuna.exceptions.TrialPruned: return best_val_f1 # Return best score before pruning except RuntimeError as e: if "CUDA out of memory" in str(e): logging.error(f"Trial {trial.number} failed: CUDA OOM. Pruning.", exc_info=False) # Less verbose OOM log raise optuna.exceptions.TrialPruned("CUDA out of memory") else: logging.error(f"Trial {trial.number} failed with unexpected RuntimeError: {e}", exc_info=True) raise optuna.exceptions.TrialPruned(f"RuntimeError: {e}") # Prune on other runtime errors except Exception as e: logging.error(f"Trial {trial.number} failed with unexpected error: {e}", exc_info=True) raise optuna.exceptions.TrialPruned(f"Unexpected error: {e}") # Prune on other errors logging.info(f"Trial {trial.number} finished. Best Val F1: {best_val_f1:.4f}") return best_val_f1 # --- Main HPO Execution Block --- if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run TimeMixerPP HPO using Optuna and YAML config.") parser.add_argument("config_path", type=str, help="Path to the YAML configuration file.") parser.add_argument("--device", type=str, default=None, help="Override device (e.g., cuda:0, cpu). If None, determined automatically or by CUDA_VISIBLE_DEVICES.") cli_args = parser.parse_args() # --- 1. 
Load Configuration --- try: with open(cli_args.config_path, 'r') as f: config = yaml.safe_load(f) logging.info(f"Loaded configuration from: {cli_args.config_path}") except FileNotFoundError: logging.error(f"Configuration file not found at: {cli_args.config_path}") sys.exit(1) except yaml.YAMLError as e: logging.error(f"Error parsing YAML configuration file: {e}") sys.exit(1) except Exception as e: logging.error(f"An unexpected error occurred loading the config: {e}") sys.exit(1) # --- 2. Determine Device --- if cli_args.device: device_str = cli_args.device logging.info(f"Using device specified via command line: {device_str}") elif "CUDA_VISIBLE_DEVICES" in os.environ: # If CUDA_VISIBLE_DEVICES is set, usually to a single ID for parallel runs, use cuda:0 # Assumes the environment variable correctly isolates the GPU if os.environ["CUDA_VISIBLE_DEVICES"].strip(): device_str = "cuda:0" logging.info(f"Using cuda:0 based on CUDA_VISIBLE_DEVICES='{os.environ['CUDA_VISIBLE_DEVICES']}'") else: # If set to empty string, means no GPU should be visible device_str = "cpu" logging.info("CUDA_VISIBLE_DEVICES is empty, using CPU.") elif torch.cuda.is_available(): device_str = "cuda:0" # Default to first GPU if available and not overridden logging.info("CUDA available, default device: cuda:0") else: device_str = "cpu" logging.info("CUDA not available, using CPU.") # Validate device choice try: if "cuda" in device_str: if not torch.cuda.is_available(): logging.warning(f"Requested CUDA device {device_str} but CUDA not available. Falling back to CPU.") device = torch.device("cpu") else: # Check if the specific device index is valid try: gpu_index = int(device_str.split(':')[-1]) if gpu_index >= torch.cuda.device_count(): logging.warning(f"Requested GPU index {gpu_index} is invalid (found {torch.cuda.device_count()} GPUs). Using cuda:0.") device = torch.device("cuda:0") else: device = torch.device(device_str) except (ValueError, IndexError): logging.warning(f"Invalid CUDA device format: '{device_str}'. Using cuda:0.") device = torch.device("cuda:0") else: device = torch.device("cpu") logging.info(f"Final device selected for this run: {device}") except Exception as e: logging.error(f"Error setting up device '{device_str}': {e}. Exiting.") sys.exit(1) # --- 3. Load Data Once --- logging.info("--- Starting HPO Setup ---") logging.info("Loading or preprocessing data (once)...") try: _all_data, _all_labels, _actual_n_channels = load_or_preprocess_data(config) if not _all_data or _all_labels is None or _actual_n_channels is None: raise ValueError("Data loading returned invalid results.") PRELOADED_DATA = (_all_data, _all_labels, _actual_n_channels) logging.info(f"Data loaded successfully. Actual N_CHANNELS determined: {_actual_n_channels}") # Optionally update config if n_channels needs to be stored globally, but model init uses the value directly # config['fixed_params']['n_channels'] = _actual_n_channels except Exception as e: logging.error(f"Fatal error during data loading: {e}", exc_info=True) sys.exit(1) # --- 4. 
Configure Optuna Study --- hpo_cfg = config['hpo_settings'] study_name = hpo_cfg['study_name'] storage_name = hpo_cfg.get('storage') # None if not specified (in-memory) direction = hpo_cfg.get('direction', 'maximize') # Instantiate Pruner from config pruner_cfg = hpo_cfg.get('pruner') pruner = None if pruner_cfg and pruner_cfg.get('type'): try: pruner_class = getattr(optuna.pruners, pruner_cfg['type']) pruner_params = pruner_cfg.get('params', {}) pruner = pruner_class(**pruner_params) logging.info(f"Using pruner: {pruner_cfg['type']} with params {pruner_params}") except AttributeError: logging.error(f"Pruner type '{pruner_cfg['type']}' not found in optuna.pruners.") sys.exit(1) except Exception as e: logging.error(f"Error initializing pruner: {e}") sys.exit(1) else: logging.info("No pruner configured or pruner type missing.") logging.info(f"Creating/Loading Optuna study '{study_name}' (Storage: {storage_name or 'in-memory'})...") try: study = optuna.create_study( study_name=study_name, storage=storage_name, direction=direction, pruner=pruner, load_if_exists=True ) except Exception as e: logging.error(f"Error creating or loading Optuna study: {e}", exc_info=True) sys.exit(1) # --- 5. Run Optimization --- n_trials = hpo_cfg['n_trials'] n_jobs = hpo_cfg.get('n_jobs', 1) # Default to 1 job (sequential) timeout = hpo_cfg.get('timeout') # Optional timeout in seconds if n_jobs > 1 and device.type == 'cuda': logging.warning(f"n_jobs={n_jobs} requested with CUDA device. Optuna's default parallel execution might lead to GPU contention.") logging.warning("For multi-GPU parallelism, it's recommended to set n_jobs=1 in the config and launch this script multiple times,") logging.warning("each targeting a different GPU (e.g., using CUDA_VISIBLE_DEVICES=0, CUDA_VISIBLE_DEVICES=1, ...).") logging.warning("Ensure you are using a shared Optuna storage (e.g., SQLite or PostgreSQL) for parallel runs.") # Consider forcing n_jobs=1 or exiting if a single CUDA device is detected with n_jobs > 1 # if torch.cuda.device_count() == 1: # logging.info("Detected single GPU, forcing n_jobs=1 for stability.") # n_jobs = 1 logging.info(f"Starting Optuna optimization with {n_trials} trials...") logging.info(f" Parallel jobs (n_jobs): {n_jobs}") logging.info(f" Timeout: {timeout or 'Not set'}") logging.info(f" Device for this process: {device}") # Use functools.partial to pass fixed arguments (config, data, device) to the objective objective_with_args = functools.partial(objective, config=config, preloaded_data=PRELOADED_DATA, device=device) try: study.optimize( objective_with_args, n_trials=n_trials, n_jobs=n_jobs, # Pass n_jobs for potential parallelism timeout=timeout ) except Exception as e: logging.error(f"An error occurred during study optimization: {e}", exc_info=True) # The study might be partially completed. # --- 6. 
Report Results --- logging.info("\n--- HPO Finished ---") # Ensure study has trials before accessing attributes if study.trials: logging.info(f"Study statistics: ") logging.info(f" Number of finished trials: {len(study.trials)}") pruned_trials = len(study.get_trials(deepcopy=False, states=[optuna.trial.TrialState.PRUNED])) complete_trials = len(study.get_trials(deepcopy=False, states=[optuna.trial.TrialState.COMPLETE])) failed_trials = len(study.get_trials(deepcopy=False, states=[optuna.trial.TrialState.FAIL])) logging.info(f" Pruned trials: {pruned_trials}") logging.info(f" Complete trials: {complete_trials}") logging.info(f" Failed trials: {failed_trials}") try: best_trial = study.best_trial logging.info(f"Best trial:") logging.info(f" Number: {best_trial.number}") logging.info(f" Value (Best {direction} F1): {best_trial.value:.6f}") logging.info(f" Params (Suggested): ") for key, value in best_trial.params.items(): # Parameters suggested by Optuna logging.info(f" {key}: {value}") # Also log the combined HPO parameters used (including fixed ones like batch_size) if best_trial.user_attrs.get("hpo_params"): logging.info(f" User Attributes (Full HPO Params used):") for key, value in best_trial.user_attrs["hpo_params"].items(): logging.info(f" {key}: {value}") except ValueError: logging.warning("Could not determine the best trial (perhaps no trials completed successfully).") except Exception as e: logging.error(f"An error occurred retrieving best trial info: {e}") # Optional: Save study results dataframe try: results_df_path = f"{study_name}_results.csv" df = study.trials_dataframe() df.to_csv(results_df_path, index=False) logging.info(f"Study results saved to {results_df_path}") except Exception as e: logging.error(f"Failed to save study results dataframe: {e}") else: logging.warning("No trials were recorded in the study.") logging.info("HPO script finished.")
The suggestions below are architecture, training, and experiment-organization strategies for pushing **the current TimeMixerPP + Optuna pipeline** to 2025-era SOTA. Each one is described step by step with the modules and references it needs, so apply whichever fits, one at a time.
| Direction | Core idea | How to apply |
|---|---|---|
| (a) Replace the encoder outright with TS-Mamba/SSM blocks | Mamba-style selective SSMs reach Transformer-level accuracy on long sequences in linear time. | Pull in the state-spaces/mamba implementation and use it as the encoder stack in place of TimeMixerPP_Model, or fold it into the Hybrid below. (GitHub, arXiv) |
| (b) TimeMixerPP × Mamba Hybrid | Use TimeMixerPP for the coarse scales and Mamba blocks for the fine scales. | Put the two block types into an nn.ModuleList and sum them through skip connections. |
| (c) Fine-tune a pre-trained foundation model | Load a TS-Mamba, Mamba4Cast, or Hugging Face tsfm checkpoint released in 2025, freeze the encoder weights, then unfreeze progressively. | Even on a single GPU, a small number of fine-tuning steps gives a large gain. (arXiv, OpenReview) |
Code sketch
```python
import torch.nn as nn
# Note: the mamba_ssm package exposes the block as `Mamba` (not `MambaBlock`).
from mamba_ssm import Mamba


class MambaEncoder(nn.Module):
    def __init__(self, d_model, depth):
        super().__init__()
        self.blocks = nn.ModuleList(
            [Mamba(d_model) for _ in range(depth)])

    def forward(self, x):  # x: (B, T, d_model)
        for blk in self.blocks:
            x = blk(x)
        return x
```
In TimeMixerPPClassifier, either swap self.timemixer_encoder for the module above, or run the two encoders in parallel, concatenate their outputs, and unify them with projector = nn.Linear(sum_dim, d_model) (see the sketch below).
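If you go the parallel route, a minimal sketch of the concat-then-project step could look like this; HybridEncoder, tmpp_dim, and mamba_dim are illustrative names (not part of TimeMixerPP or mamba_ssm), and both branches are assumed to already pool over the time dimension:

```python
import torch
import torch.nn as nn


class HybridEncoder(nn.Module):
    """Illustrative parallel fusion of two encoder branches."""
    def __init__(self, tmpp_encoder: nn.Module, mamba_encoder: nn.Module,
                 tmpp_dim: int, mamba_dim: int, d_model: int):
        super().__init__()
        self.tmpp_encoder = tmpp_encoder    # e.g. the existing TimeMixerPP-based branch
        self.mamba_encoder = mamba_encoder  # e.g. the MambaEncoder sketched above
        self.projector = nn.Linear(tmpp_dim + mamba_dim, d_model)

    def forward(self, x):
        # Assumes each branch already pools over time and returns (B, branch_dim).
        h_a = self.tmpp_encoder(x)
        h_b = self.mamba_encoder(x)
        return self.projector(torch.cat([h_a, h_b], dim=-1))  # (B, d_model) for the MLP heads
```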
Downsample seq_len to {L, L/2, L/4, …}, feed each resolution into its own encoder, and merge them at the end with gated additive fusion: GatedFusion(h₀, h₁) = g ⊙ h₀ + (1 − g) ⊙ h₁ with g = σ(W₁·[h₀; h₁]) (a sketch follows the table below).

| Item | Implementation tip |
|---|---|
| Class-Balanced Focal Loss (per task) | alpha = effective_num / Σ effective_num; gamma = 2.0 (sketch below the table) |
| Label-Aware Mixup/CutMix | Only mix samples that share the same labels → keeps the positive/negative ratio intact. |
| Evaluation | Log subset accuracy and micro-F1 in addition to average macro-F1; store them in a WandB table. |
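For the gated additive fusion mentioned above the table, a minimal sketch, assuming each scale has already been pooled to a (B, d_model) vector; the class name and weight shapes are illustrative:

```python
import torch
import torch.nn as nn


class GatedFusion(nn.Module):
    """out = g * h0 + (1 - g) * h1, with g = sigmoid(W1 [h0; h1])."""
    def __init__(self, d_model: int):
        super().__init__()
        self.gate = nn.Linear(2 * d_model, d_model)

    def forward(self, h0: torch.Tensor, h1: torch.Tensor) -> torch.Tensor:
        g = torch.sigmoid(self.gate(torch.cat([h0, h1], dim=-1)))
        return g * h0 + (1.0 - g) * h1
```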
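For the Class-Balanced Focal Loss row, a sketch of a per-task drop-in replacement for the current nn.CrossEntropyLoss; it assumes the usual inverse-effective-number weighting (the exact normalization in the table is ambiguous) with gamma = 2.0, and class counts being available per label:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ClassBalancedFocalLoss(nn.Module):
    """Focal loss whose per-class alpha comes from effective numbers of samples."""
    def __init__(self, class_counts, beta: float = 0.999, gamma: float = 2.0):
        super().__init__()
        counts = torch.as_tensor(class_counts, dtype=torch.float32)
        effective_num = 1.0 - torch.pow(beta, counts)
        alpha = (1.0 - beta) / effective_num
        alpha = alpha / alpha.sum() * len(counts)  # keep the average weight near 1
        self.register_buffer("alpha", alpha)
        self.gamma = gamma

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        log_probs = F.log_softmax(logits, dim=-1)
        ce = F.nll_loss(log_probs, targets, weight=self.alpha, reduction="none")
        p_t = log_probs.exp().gather(1, targets.unsqueeze(1)).squeeze(1)
        return ((1.0 - p_t) ** self.gamma * ce).mean()

# Usage sketch: build one instance per label head from that label's class counts and
# use it in place of `criterion` inside the per-task loss loop of train_epoch_timemixer.
```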
Extend the Optuna pruning rule and search space (a sketch follows the table below):

- Prune a trial when its val_loss comes in more than 10% above the previous best.
- encoder_type ∈ {TimeMixerPP, Mamba, Hybrid}
- d_model: [128-768] (log scale), depth: [4-24], kernel_size, etc.
- augmentation ⊂ {jitter, scaling, time_warp} (categorical multi-choice)

| Improvement | Details |
|---|---|
| Stratified GroupKFold | Keep grouping at the subject_id level to prevent leakage (see the split sketch after this table). |
| TorchData/Streaming | Lazy-load Parquet through an IterableDataset; reduces GPU-host pin_memory pressure. |
| Mixed Precision & Gradient Checkpointing | A GradScaler is already in use → applying torch.utils.checkpoint to the encoder cuts VRAM by roughly 20-40%. |
| SWA + EMA | Run SWA updates over the last 20% of epochs, with EMA (0.999) in parallel → better generalization. |
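Sketch of the architecture-level search space and the 10%-worse-than-best pruning rule listed above the table; it follows the same trial.suggest_* style the script's suggest_hyperparameters already produces, and the kernel-size choices plus the aug_* flag names are assumptions:

```python
import optuna


def suggest_architecture(trial: optuna.trial.Trial) -> dict:
    params = {
        "encoder_type": trial.suggest_categorical("encoder_type", ["TimeMixerPP", "Mamba", "Hybrid"]),
        "d_model": trial.suggest_int("d_model", 128, 768, log=True),
        "depth": trial.suggest_int("depth", 4, 24),
        "kernel_size": trial.suggest_categorical("kernel_size", [13, 25, 49]),  # illustrative choices
    }
    # Optuna has no native "subset" categorical, so one common workaround is a boolean flag per augmentation.
    for aug in ["jitter", "scaling", "time_warp"]:
        params[f"aug_{aug}"] = trial.suggest_categorical(f"aug_{aug}", [True, False])
    return params

# Inside the epoch loop of objective(), prune when val_loss drifts >10% above the best so far, e.g.:
# if best_val_loss is not None and val_loss > 1.10 * best_val_loss:
#     raise optuna.exceptions.TrialPruned()
```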
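For the Stratified GroupKFold row, a sketch of a subject-level split; it assumes subject_id is retained per processed sample (the current preprocess_data_timeseries does not return it, so that needs a small change) and stratifies on the first label column, mirroring the existing train_test_split call:

```python
import numpy as np
from sklearn.model_selection import StratifiedGroupKFold


def make_group_folds(labels: np.ndarray, subject_ids: np.ndarray,
                     n_splits: int = 5, seed: int = 42):
    """Yield (train_idx, val_idx) pairs that never place one subject in both sets."""
    sgkf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    dummy_x = np.zeros(len(labels))  # features are irrelevant to the split itself
    yield from sgkf.split(dummy_x, y=labels[:, 0], groups=subject_ids)
```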
Switch the positional encoding to the relative (orthogonal-polynomial) + ΔTime combination recommended in recent surveys; this typically adds 1-2 points for Transformer-family models. (arXiv)
```python
class LearnedRelPos(nn.Module):
    def __init__(self, L, d):
        super().__init__()
        self.rel = nn.Parameter(torch.randn(L, d))

    def forward(self, T):
        return self.rel[:T]
```
Feed the encoder x + pos[:T] + delta_t_embed as its input (see the sketch below).
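A sketch of how the three terms might be combined at the embedding step; DeltaTimeEmbed is a name introduced here (not part of the pipeline), and timestamps are assumed to arrive as a (B, T) float tensor of seconds, e.g. derived from x_mark:

```python
import torch
import torch.nn as nn


class DeltaTimeEmbed(nn.Module):
    """Project the gap between consecutive timestamps into d_model (hypothetical helper)."""
    def __init__(self, d_model: int):
        super().__init__()
        self.proj = nn.Linear(1, d_model)

    def forward(self, timestamps: torch.Tensor) -> torch.Tensor:  # (B, T)
        delta = torch.diff(timestamps, dim=1, prepend=timestamps[:, :1])  # (B, T), first gap = 0
        return self.proj(delta.unsqueeze(-1))                              # (B, T, d_model)

# Embedding step (sketch): h = value_embedding(x) + rel_pos(T) + delta_time_embed(timestamps)
```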
If you proceed brick by brick, logging each small experiment as you go, you can expect the overall macro-F1 to improve by roughly 5-10 points. If you'd like more code snippets for any particular section, just say the word!