Advancing the current multi-label, multi-class classification training pipeline to a SOTA-level architecture

Created: May 3, 2025

Model used for the answer: O3 by Chat01

Question

I need to advance the architecture of my current multi-label, multi-class classification training pipeline to the SOTA level:

python
# train_hpo.py import argparse import os import random import warnings import logging import sys import pickle import yaml # For loading YAML config import functools # For passing args to objective from tqdm import tqdm import numpy as np import pandas as pd import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F from torch.amp import GradScaler, autocast from torch.utils.data import Dataset, DataLoader, default_collate from sklearn.model_selection import train_test_split from sklearn.metrics import f1_score import optuna # --- TimeMixerPP Imports and Patches --- # Assume TimeMixerPP directory is in PYTHONPATH or current dir try: sys.path.append('./TimeMixerPP') # Ensure TimeMixerPP is findable from models.TimeMixerPP import Model as TimeMixerPP_Model, MixerBlock from layers.Embed import DataEmbedding_wo_pos from utils.tools import dotdict # Keep dotdict if used internally by TimeMixerPP from utils.metrics import metric import models.TimeMixerPP as timemixerpp_module except ImportError as e: print(f"Error importing TimeMixerPP modules: {e}") print("Please ensure the TimeMixerPP directory is in your Python path or run from the TimeMixerPP root.") sys.exit(1) # --- Logging Setup --- log_level = logging.INFO # Default, can be overridden by config later if needed logging.basicConfig(level=log_level, format='%(asctime)s - %(levelname)s - [%(process)d] %(message)s', stream=sys.stderr) warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=RuntimeWarning) # --- FFT Patch --- # --- (Patched FFT code - unchanged from previous version) --- try: original_FFT_for_Period = timemixerpp_module.FFT_for_Period except AttributeError: logging.warning("Could not find original FFT_for_Period function."); original_FFT_for_Period = None def patched_FFT_for_Period(x, k=2): if not isinstance(x, torch.Tensor): raise TypeError(f"Input x must be a torch.Tensor, got {type(x)}") original_dtype = x.dtype if x.dtype != torch.float32: x_float32 = x.to(torch.float32) else: x_float32 = x try: xf = torch.fft.rfft(x_float32, dim=1) except Exception as e: logging.error(f"Error during torch.fft.rfft: {e}"); logging.error(f"Input shape: {x_float32.shape}, dtype: {x_float32.dtype}, device: {x_float32.device}"); raise e if not torch.is_complex(xf): logging.warning(f"torch.fft.rfft output is not complex (dtype: {xf.dtype}). This might indicate an issue."); xf_abs = abs(xf) else: xf_abs = abs(xf) if xf_abs.numel() == 0: return np.array([], dtype=np.int64), torch.tensor([]).to(x.device).reshape(x.shape[0], 0), np.array([]) frequency_list = xf_abs.mean(0); if frequency_list.dim() > 1: frequency_list = frequency_list.mean(-1) frequency_list[0] = 0 actual_k = min(k, len(frequency_list)) if actual_k <= 0: return np.array([], dtype=np.int64), torch.tensor([]).to(x.device).reshape(x.shape[0], 0), np.array([]) try: _, top_list_indices = torch.topk(frequency_list, actual_k) except RuntimeError as e: logging.error(f"Error in torch.topk: {e}. 
frequency_list size: {len(frequency_list)}, k: {actual_k}") if actual_k > len(frequency_list): actual_k = len(frequency_list) if actual_k == 0: return np.array([], dtype=np.int64), torch.tensor([]).to(x.device).reshape(x.shape[0], 0), np.array([]) _, top_list_indices = torch.topk(frequency_list, actual_k) else: raise e top_list_indices_cpu = top_list_indices.detach().cpu().numpy() periods = np.zeros_like(top_list_indices_cpu, dtype=float) valid_indices_mask = top_list_indices_cpu != 0 safe_indices = top_list_indices_cpu[valid_indices_mask] if np.any(safe_indices == 0): logging.warning("Found zero index in FFT top list after initial filtering. Removing.") zero_mask = safe_indices != 0 safe_indices = safe_indices[zero_mask] valid_indices_mask = np.isin(top_list_indices_cpu, safe_indices) if np.any(valid_indices_mask): periods[valid_indices_mask] = x.shape[1] / top_list_indices_cpu[valid_indices_mask] valid_period_mask = periods > 0 top_list_filtered = top_list_indices_cpu[valid_period_mask]; periods_filtered = periods[valid_period_mask] periods_filtered_int = np.round(periods_filtered).astype(np.int64) non_zero_mask_after_round = periods_filtered_int > 0 periods_filtered_int = periods_filtered_int[non_zero_mask_after_round]; top_list_filtered = top_list_filtered[non_zero_mask_after_round] top_list_indices_tensor = torch.from_numpy(top_list_filtered).long().to(frequency_list.device) if top_list_indices_tensor.numel() > 0: weights_all_freqs = xf_abs if weights_all_freqs.dim() == 1: weights_all_freqs = weights_all_freqs.unsqueeze(0) if weights_all_freqs.dim() == 3: weights_all_freqs = weights_all_freqs.mean(-1) elif weights_all_freqs.dim() != 2: raise ValueError(f"Unexpected shape for weights_all_freqs: {weights_all_freqs.shape}") max_freq_index = weights_all_freqs.shape[1] - 1 valid_fft_indices_mask = top_list_indices_tensor <= max_freq_index valid_fft_indices = top_list_indices_tensor[valid_fft_indices_mask] if len(valid_fft_indices) != len(top_list_indices_tensor): logging.warning(f"Some top FFT indices out of bounds. Filtering.") periods_filtered_int = periods_filtered_int[valid_fft_indices_mask.cpu().numpy()] top_list_filtered = top_list_filtered[valid_fft_indices_mask.cpu().numpy()] top_list_indices_tensor = valid_fft_indices if top_list_indices_tensor.numel() > 0: period_weights = weights_all_freqs[:, top_list_indices_tensor] else: period_weights = torch.tensor([]).to(x.device).reshape(xf_abs.shape[0], 0) else: period_weights = torch.tensor([]).to(x.device).reshape(xf_abs.shape[0], 0) if not isinstance(periods_filtered_int, np.ndarray): raise TypeError("periods_filtered_int is not np.ndarray") if not isinstance(period_weights, torch.Tensor): raise TypeError("period_weights is not torch.Tensor") if not isinstance(top_list_filtered, np.ndarray): raise TypeError("top_list_filtered is not np.ndarray") return periods_filtered_int, period_weights, top_list_filtered if hasattr(timemixerpp_module, 'FFT_for_Period') and original_FFT_for_Period is not None: timemixerpp_module.FFT_for_Period = patched_FFT_for_Period logging.debug("Successfully patched 'models.TimeMixerPP.FFT_for_Period'.") elif not hasattr(timemixerpp_module, 'FFT_for_Period'): logging.warning("Could not find 'models.TimeMixerPP.FFT_for_Period'. Patching skipped.") else: logging.warning("Original FFT_for_Period not found during setup. 
Patching skipped.") # --- multi_reso_mixing Patch --- # --- (Patched multi_reso_mixing code - unchanged from previous version) --- try: original_multi_reso_mixing = MixerBlock.multi_reso_mixing except AttributeError: logging.warning("Could not find original MixerBlock.multi_reso_mixing method."); original_multi_reso_mixing = None def patched_multi_reso_mixing(self, period_weight, x, res): B_actual, T_actual, N_actual = x.size() if period_weight.shape[1] > 0: period_weight_softmax = F.softmax(period_weight, dim=1) period_weight_reshaped = period_weight_softmax.unsqueeze(1).unsqueeze(1) period_weight_repeated = period_weight_reshaped.repeat(1, T_actual, N_actual, 1) res_aggregated = torch.sum(res * period_weight_repeated, -1) else: res_aggregated = torch.zeros_like(x) res_final = res_aggregated + x return res_final if hasattr(MixerBlock, 'multi_reso_mixing') and original_multi_reso_mixing is not None: MixerBlock.multi_reso_mixing = patched_multi_reso_mixing logging.debug("Successfully patched 'models.TimeMixerPP.MixerBlock.multi_reso_mixing'.") elif not hasattr(MixerBlock, 'multi_reso_mixing'): logging.warning("Could not find 'models.TimeMixerPP.MixerBlock.multi_reso_mixing'. Patching skipped.") else: logging.warning("Original multi_reso_mixing not found during setup. Patching skipped.") # --- Utility Functions --- def control_randomness(seed: int = 42): random.seed(seed); os.environ["PYTHONHASHSEED"] = str(seed); np.random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed(seed) if torch.cuda.is_available(): torch.backends.cudnn.deterministic = True; torch.backends.cudnn.benchmark = False; torch.cuda.manual_seed_all(seed) # --- Data Loading & Preprocessing (Takes config dict) --- def load_or_preprocess_data(config: dict): data_cfg = config['data_settings'] fixed_cfg = config['fixed_params'] cache_dir = data_cfg.get('cache_dir', './preprocess_cache') # Use get for optional cache cache_data_path = os.path.join(cache_dir, 'processed_data.pkl') cache_labels_path = os.path.join(cache_dir, 'processed_labels.pkl') cache_n_channels_path = os.path.join(cache_dir, 'n_channels.txt') os.makedirs(cache_dir, exist_ok=True) # --- (Data loading/preprocessing code - uses paths/params from config) --- if (data_cfg.get('use_cache', True) and # Check if caching enabled os.path.exists(cache_data_path) and os.path.exists(cache_labels_path) and os.path.exists(cache_n_channels_path)): logging.info(f"Loading preprocessed data from cache: {cache_dir}") try: with open(cache_data_path, 'rb') as f: all_data = pickle.load(f) with open(cache_labels_path, 'rb') as f: all_labels = pickle.load(f) with open(cache_n_channels_path, 'r') as f: actual_n_channels = int(f.read()) logging.info(f"Loaded {len(all_data)} samples from cache. Actual N_CHANNELS: {actual_n_channels}") # --- Add Validation: Check if cached seq_len matches config --- if len(all_data) > 0 and all_data[0].shape[0] != data_cfg['seq_len']: logging.warning(f"Cached data seq_len ({all_data[0].shape[0]}) != config seq_len ({data_cfg['seq_len']}). Reprocessing...") raise ValueError("Sequence length mismatch in cache.") # --- Add Validation: Check if cached channels matches config (if provided) --- # This is tricky as n_channels is usually *determined* by data, not set in config. # We'll rely on the preprocess function to determine it and save it. return all_data, all_labels, actual_n_channels except Exception as e: logging.warning(f"Failed to load data from cache: {e}. 
Reprocessing...") if os.path.exists(cache_data_path): os.remove(cache_data_path) if os.path.exists(cache_labels_path): os.remove(cache_labels_path) if os.path.exists(cache_n_channels_path): os.remove(cache_n_channels_path) logging.info("Cache not found, loading failed, or cache disabled. Starting data preprocessing...") all_data, all_labels, actual_n_channels = preprocess_data_timeseries(config) # Pass full config if data_cfg.get('use_cache', True): logging.info(f"Caching preprocessed data to {cache_dir}...") try: with open(cache_data_path, 'wb') as f: pickle.dump(all_data, f) with open(cache_labels_path, 'wb') as f: pickle.dump(all_labels, f) with open(cache_n_channels_path, 'w') as f: f.write(str(actual_n_channels)) logging.info("Successfully cached preprocessed data.") except Exception as e: logging.error(f"Failed to cache preprocessed data: {e}") return all_data, all_labels, actual_n_channels def preprocess_data_timeseries(config: dict): data_cfg = config['data_settings'] fixed_cfg = config['fixed_params'] target_seq_len = data_cfg['seq_len'] label_names = fixed_cfg['label_names'] time_series_path = data_cfg['train_path'] label_path = data_cfg['label_path'] # n_channels_config is not used here, it's determined from data logging.info(f"Loading raw data from {time_series_path} and {label_path}...") try: ts_df = pd.read_parquet(time_series_path) label_df = pd.read_csv(label_path) except FileNotFoundError as e: logging.error(f"Error loading raw data files: {e}"); raise # --- (Rest of preprocessing logic - unchanged, uses local vars derived from config) --- if 'timestamp' in ts_df.columns: ts_df['timestamp'] = pd.to_datetime(ts_df['timestamp']) ts_df['lifelog_date'] = ts_df['timestamp'].dt.date else: if 'lifelog_date' not in ts_df.columns: raise ValueError("Time series data must contain 'timestamp' or 'lifelog_date'.") ts_df['lifelog_date'] = pd.to_datetime(ts_df['lifelog_date']).dt.date if 'lifelog_date' in label_df.columns: label_df['lifelog_date'] = pd.to_datetime(label_df['lifelog_date']).dt.date else: raise ValueError("Label data must contain 'lifelog_date'.") potential_keys = ['timestamp', 'subject_id', 'lifelog_date'] sensor_cols = [col for col in ts_df.columns if col not in potential_keys] if not sensor_cols: raise ValueError("Could not automatically identify sensor columns.") actual_n_channels = len(sensor_cols) n_channels = actual_n_channels # Use actual channel count logging.info(f"Identified {n_channels} sensor columns: {sensor_cols[:5]}...") logging.info("Merging time series data with labels...") merge_keys = [] if 'subject_id' in ts_df.columns and 'subject_id' in label_df.columns: merge_keys.append('subject_id') if 'lifelog_date' in ts_df.columns and 'lifelog_date' in label_df.columns: merge_keys.append('lifelog_date') if not merge_keys: raise ValueError("Cannot merge dataframes. Missing common keys ('subject_id' and/or 'lifelog_date').") logging.info(f"Merging on keys: {merge_keys}") required_label_cols = merge_keys + label_names missing_label_cols = [col for col in required_label_cols if col not in label_df.columns] if missing_label_cols: raise ValueError(f"Label dataframe is missing required columns: {missing_label_cols}") merged_df = pd.merge(ts_df, label_df[required_label_cols], on=merge_keys, how='inner') logging.info(f"Merged dataframe shape: {merged_df.shape}") if merged_df.empty: raise ValueError("Merged dataframe is empty. 
Check merge keys and data content.") logging.info(f"Processing time series data per day for target length {target_seq_len}...") processed_data_list = [] processed_labels = [] grouping_keys = [key for key in merge_keys if key in merged_df.columns] if not grouping_keys: raise ValueError("Grouping keys not found after merge.") grouped = merged_df.groupby(grouping_keys) logging.info(f"Grouping data by: {grouping_keys}. Number of groups: {len(grouped)}") for name, group in tqdm(grouped, desc="Processing Days"): try: labels = group.iloc[0][label_names].values.astype(np.int64).tolist() except Exception as e: logging.error(f"Error converting labels for group {name}: {e}. Labels: {group.iloc[0][label_names].values}"); continue day_data = group[sensor_cols].fillna(0).values.astype(np.float32) current_len = day_data.shape[0] if current_len == 0: logging.warning(f"Skipping group {name} due to empty sensor data."); continue if current_len == target_seq_len: processed_day_data = day_data elif current_len < target_seq_len: padding_len = target_seq_len - current_len padding = np.zeros((padding_len, n_channels), dtype=np.float32) processed_day_data = np.concatenate([day_data, padding], axis=0) else: processed_day_data = day_data[-target_seq_len:, :] if processed_day_data.shape != (target_seq_len, n_channels): logging.error(f"Shape mismatch for group {name} after processing: expected {(target_seq_len, n_channels)}, got {processed_day_data.shape}. Skipping."); continue processed_data_list.append(processed_day_data) processed_labels.append(labels) logging.info(f"Finished processing. Found {len(processed_data_list)} valid samples.") if not processed_data_list: raise ValueError("No data processed. Check input files and processing logic.") labels_array = np.array(processed_labels, dtype=np.int64) return processed_data_list, labels_array, n_channels # --- Dataset Class --- class TimeSeriesDataset(Dataset): # --- (Dataset code - unchanged) --- def __init__(self, data_list, labels): self.data_list = data_list self.labels = labels def __len__(self): return len(self.labels) def __getitem__(self, idx): series_data_np = self.data_list[idx] if series_data_np.dtype != np.float32: series_data_np = series_data_np.astype(np.float32) series_data = torch.from_numpy(series_data_np) label_vector_np = self.labels[idx] if label_vector_np.dtype != np.int64: label_vector_np = label_vector_np.astype(np.int64) label_vector = torch.from_numpy(label_vector_np) if not isinstance(series_data, torch.Tensor) or not isinstance(label_vector, torch.Tensor): raise TypeError(f"Dataset __getitem__ returned non-Tensor types for index {idx}. 
Got {type(series_data)} and {type(label_vector)}") return series_data, label_vector # --- Model Definitions --- class IndependentMLPHeads(nn.Module): # --- (MLP Head code - unchanged) --- def __init__(self, input_dim, num_classes_per_label, label_names, hidden_dim=512, dropout=0.1): super().__init__() self.heads = nn.ModuleDict() for i, name in enumerate(label_names): num_classes = num_classes_per_label[i] self.heads[name] = nn.Sequential( nn.Linear(input_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(), nn.Dropout(dropout), nn.Linear(hidden_dim, num_classes) ) def forward(self, x): outputs = {} for name, head in self.heads.items(): outputs[name] = head(x) return outputs class TimeMixerPPClassifier(nn.Module): # --- (Classifier code - takes config dict for init) --- def __init__(self, model_config: dict, num_classes_per_label: list, label_names: list): super().__init__() self.configs = argparse.Namespace(**model_config) # Use Namespace for attribute access self.label_names = label_names self.timemixer_encoder = TimeMixerPP_Model(self.configs) if self.configs.channel_independence == 1: encoder_output_dim = self.configs.d_model else: encoder_output_dim = self.configs.d_model; logging.warning("Channel dependence used. Ensure aggregation handles multiple channels.") self.classifier_head = IndependentMLPHeads( input_dim=encoder_output_dim, num_classes_per_label=num_classes_per_label, label_names=label_names, hidden_dim=self.configs.mlp_hidden_dim, # Get from model_config dropout=self.configs.mlp_dropout # Get from model_config ) def forward(self, x_enc, x_mark_enc=None): # --- (Forward pass logic - unchanged) --- try: x_enc_list, x_mark_enc_list = self.timemixer_encoder._Model__multi_scale_process_inputs(x_enc, x_mark_enc) except AttributeError: logging.error("Could not find '_Model__multi_scale_process_inputs'."); raise except Exception as e: logging.error(f"Error during multi_scale_process_inputs: {e}"); raise x_processed_list = []; x_mark_processed_list = [] revin_layers = self.timemixer_encoder.revin_layers marks_to_process = x_mark_enc_list if x_mark_enc_list is None: marks_to_process = [None] * len(x_enc_list) elif len(x_enc_list) != len(x_mark_enc_list): raise ValueError(f"Mismatch lengths x_enc_list/x_mark_enc_list.") for i, (x, x_mark) in enumerate(zip(x_enc_list, marks_to_process)): if hasattr(self.configs, 'revin') and self.configs.revin == 1: if i < len(revin_layers): x = revin_layers[i](x, 'norm') else: logging.warning(f"RevIN layer index {i} out of bounds. Skipping.") if self.configs.channel_independence == 1: B, T_i, N_i = x.shape x = x.permute(0, 2, 1).contiguous().reshape(B * N_i, T_i, 1) if x_mark is not None: x_mark = x_mark.repeat_interleave(N_i, dim=0) x_processed_list.append(x); x_mark_processed_list.append(x_mark) if self.configs.channel_mixing and self.configs.channel_independence == 1: _, T_coarse, D_model = x_processed_list[-1].size() B_orig, N_orig = x_enc.shape[0], self.configs.enc_in coarse_scale_enc_out_flat = x_processed_list[-1].reshape(B_orig, N_orig, T_coarse * D_model) if hasattr(self.timemixer_encoder, 'channel_mixing_attention'): coarse_scale_enc_out_mixed, _ = self.timemixer_encoder.channel_mixing_attention(coarse_scale_enc_out_flat, coarse_scale_enc_out_flat, coarse_scale_enc_out_flat, None) x_processed_list[-1] = coarse_scale_enc_out_mixed.reshape(B_orig * N_orig, T_coarse, D_model) + x_processed_list[-1] else: logging.warning("Channel mixing enabled, but layer not found. 
Skipping.") enc_out_list = [] embedding_layer = self.timemixer_encoder.enc_embedding if not hasattr(embedding_layer, '__call__'): raise AttributeError("No callable 'enc_embedding'.") for x, x_mark in zip(x_processed_list, x_mark_processed_list): enc_out = embedding_layer(x, x_mark); enc_out_list.append(enc_out) processed_enc_out_list = enc_out_list encoder_layers = self.timemixer_encoder.encoder_model if not isinstance(encoder_layers, nn.ModuleList): raise AttributeError("No 'encoder_model' ModuleList.") for i in range(self.configs.e_layers): if i >= len(encoder_layers): logging.warning(f"Configured e_layers > actual layers. Stopping."); break layer = encoder_layers[i] if not callable(layer): raise TypeError(f"Encoder layer {i} not callable.") processed_enc_out_list = layer(processed_enc_out_list) if not isinstance(processed_enc_out_list, list): raise TypeError(f"MixerBlock {i} didn't return list.") if not processed_enc_out_list: raise ValueError(f"MixerBlock {i} returned empty list.") if not isinstance(processed_enc_out_list[0], torch.Tensor): raise TypeError(f"MixerBlock {i} list has non-Tensor.") if not processed_enc_out_list: raise ValueError("Encoder output list empty.") final_enc_out = processed_enc_out_list[0] if not isinstance(final_enc_out, torch.Tensor): raise TypeError(f"Final enc_out not Tensor.") aggregated_features = final_enc_out.mean(dim=1) if self.configs.channel_independence == 1: B_actual, D = aggregated_features.shape N_ch = self.configs.enc_in if N_ch == 0: raise ValueError("N_ch cannot be zero.") # --- Add check for division by zero if N_ch might be 0 --- if N_ch > 0 and B_actual % N_ch != 0: logging.error(f"Shape mismatch B_actual={B_actual}, N_ch={N_ch}. Check drop_last=True.") raise ValueError("Shape mismatch.") B_inferred = B_actual // N_ch if N_ch > 0 else B_actual # Avoid division by zero # Reshape only if N_ch > 0 and reshaping is needed if N_ch > 0 and B_actual != B_inferred: aggregated_features = aggregated_features.view(B_inferred, N_ch, D).mean(dim=1) # else: # If N_ch=0 or B_actual == B_inferred, use as is (or handle error) # pass # Or raise error if N_ch == 0 is unexpected outputs = self.classifier_head(aggregated_features) return outputs # --- Evaluation Metric --- def calculate_avg_macro_f1(all_true_labels, all_pred_logits, label_names): # --- (F1 calculation code - unchanged) --- f1_scores = {} num_samples = all_true_labels.shape[0] for i, name in enumerate(label_names): true_task = all_true_labels[:, i] pred_task_logits = all_pred_logits[name].cpu() pred_task_labels = torch.argmax(pred_task_logits, dim=1).numpy() true_task_np = true_task.cpu().numpy() if len(np.unique(true_task_np)) < 2 and len(np.unique(pred_task_labels)) < 2: f1 = 1.0 if np.array_equal(true_task_np, pred_task_labels) else 0.0 else: f1 = f1_score(true_task_np, pred_task_labels, average='macro', zero_division=0) f1_scores[name] = f1 avg_f1 = np.mean(list(f1_scores.values())) return avg_f1, f1_scores # --- Training and Evaluation Functions (Take config) --- criterion = nn.CrossEntropyLoss() def train_epoch_timemixer(model, device, train_dataloader, criterion, optimizer, scheduler, scaler, config: dict, current_epoch: int): # --- (Train epoch code - largely unchanged, uses config for params) --- model.train() total_loss = 0.0 optimizer.zero_grad() amp_enabled = scaler.is_enabled() label_names = config['fixed_params']['label_names'] accumulation_steps = config['hpo_params']['accumulation_steps'] # Get from HPO params now gradient_clip_val = config['hpo_params'].get('gradient_clip_val') # 
Optional clip val # No tqdm for HPO runs by default to reduce log spam # pbar = tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f"Epoch {current_epoch} Train", leave=False) for i, (batch_x, batch_labels) in enumerate(train_dataloader): batch_x = batch_x.to(device) batch_labels = batch_labels.to(device) batch_x_mark = None with autocast(device_type=device.type, dtype=torch.float16, enabled=amp_enabled): outputs = model(batch_x, batch_x_mark) batch_loss = 0.0 for task_idx, name in enumerate(label_names): loss = criterion(outputs[name], batch_labels[:, task_idx]) batch_loss += loss batch_loss /= len(label_names) loss_accum = batch_loss / accumulation_steps scaler.scale(loss_accum).backward() if (i + 1) % accumulation_steps == 0: if gradient_clip_val is not None and gradient_clip_val > 0: # Check > 0 scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip_val) scaler.step(optimizer) scaler.update() optimizer.zero_grad() if scheduler is not None and isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR): scheduler.step() total_loss += batch_loss.item() # pbar.set_postfix({'loss': batch_loss.item()}) avg_loss = total_loss / len(train_dataloader) if len(train_dataloader) > 0 else 0 if scheduler is not None and not isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR): scheduler.step() return avg_loss def evaluate_epoch_timemixer(model, device, dataloader, criterion, config: dict, phase='val', current_epoch=0): # --- (Evaluate epoch code - largely unchanged, uses config) --- model.eval() all_true_labels = [] all_pred_logits = {name: [] for name in config['fixed_params']['label_names']} total_loss = 0.0 amp_enabled = (device.type == 'cuda') label_names = config['fixed_params']['label_names'] with torch.no_grad(): # pbar = tqdm(dataloader, desc=f"Epoch {current_epoch} Eval {phase}", leave=False) for batch_x, batch_labels in dataloader: batch_x = batch_x.to(device) batch_labels = batch_labels.to(device) batch_x_mark = None with autocast(device_type=device.type, dtype=torch.float16, enabled=amp_enabled): outputs = model(batch_x, batch_x_mark) batch_loss = 0.0 for i, name in enumerate(label_names): task_logits = outputs[name] loss = criterion(task_logits, batch_labels[:, i]) batch_loss += loss all_pred_logits[name].append(task_logits.cpu().float()) batch_loss /= len(label_names) total_loss += batch_loss.item() all_true_labels.append(batch_labels.cpu()) # pbar.set_postfix({'loss': batch_loss.item()}) avg_loss = total_loss / len(dataloader) if len(dataloader) > 0 else 0 all_true_labels_cat = torch.cat(all_true_labels, dim=0) if all_true_labels else torch.empty(0, len(label_names)) all_pred_logits_cat = {name: torch.cat(all_pred_logits[name], dim=0) for name in label_names} if all_pred_logits[label_names[0]] else {name: torch.empty(0) for name in label_names} # Handle case with no evaluation data if all_true_labels_cat.numel() == 0: logging.warning("Evaluation dataloader was empty. 
Returning F1=0.") return avg_loss, 0.0, {name: 0.0 for name in label_names} avg_f1, f1_scores_dict = calculate_avg_macro_f1(all_true_labels_cat, all_pred_logits_cat, label_names) return avg_loss, avg_f1, f1_scores_dict # --- Optuna Objective Function (Takes config) --- def suggest_hyperparameters(trial: optuna.trial.Trial, search_space_config: dict): """Suggests hyperparameters based on the search space config.""" params = {} for name, cfg in search_space_config.items(): suggest_type = cfg['type'] if suggest_type == 'categorical': params[name] = trial.suggest_categorical(name, cfg['choices']) elif suggest_type == 'int': params[name] = trial.suggest_int(name, cfg['low'], cfg['high'], step=cfg.get('step', 1), log=cfg.get('log', False)) elif suggest_type == 'float': params[name] = trial.suggest_float(name, cfg['low'], cfg['high'], step=cfg.get('step'), log=cfg.get('log', False)) else: raise ValueError(f"Unsupported suggestion type: {suggest_type} for parameter {name}") return params def objective(trial: optuna.trial.Trial, config: dict, preloaded_data: tuple, device: torch.device) -> float: """Optuna objective function driven by YAML config.""" # --- 1. Suggest Hyperparameters based on Config --- try: hpo_params = suggest_hyperparameters(trial, config['search_space']) # Combine suggested HPO params with fixed HPO params from config hpo_params.update(config.get('hpo_params', {})) # Adds batch_size, accumulation_steps etc. if defined there # Store combined params for the trial (useful for analysis) trial.set_user_attr("hpo_params", hpo_params) except Exception as e: logging.error(f"Error suggesting hyperparameters: {e}", exc_info=True) raise # Re-raise to mark trial as failed # --- 2. Setup Trial --- control_randomness(config['fixed_params']['seed'] + trial.number) # Seed per trial for reproducibility *if needed* all_data, all_labels, n_channels = preloaded_data if not all_data or all_labels is None or n_channels is None: logging.error("Preloaded data is missing or invalid!") return -1.0 # Indicate failure # --- 3. Data Splitting and Loaders --- indices = np.arange(len(all_data)) stratify_col = all_labels[:, 0] if all_labels.shape[1] > 0 and len(np.unique(all_labels[:, 0])) >= 2 else None train_idx, val_idx = train_test_split( indices, test_size=config['data_settings']['val_split_ratio'], random_state=config['fixed_params']['seed'], # Fixed seed for split consistency across trials stratify=stratify_col ) train_dataset = TimeSeriesDataset([all_data[i] for i in train_idx], all_labels[train_idx]) val_dataset = TimeSeriesDataset([all_data[i] for i in val_idx], all_labels[val_idx]) # Use hpo_params for batch size and accumulation steps batch_size = hpo_params['batch_size'] accumulation_steps = hpo_params['accumulation_steps'] num_workers = config['hpo_settings'].get('num_workers', min(os.cpu_count() // 2 if os.cpu_count() else 0, 4)) train_loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=(num_workers > 0 and device.type == 'cuda'), drop_last=True ) val_loader = DataLoader( val_dataset, batch_size=batch_size * 2, shuffle=False, num_workers=num_workers, pin_memory=(num_workers > 0 and device.type == 'cuda'), drop_last=False ) # --- 4. 
Model Initialization --- # Combine fixed params and HPO params for model config model_config_dict = {**config['fixed_params'], **hpo_params} # HPO params override fixed ones if names clash model_config_dict['enc_in'] = n_channels # Use actual N channels model_config_dict['c_out'] = n_channels # Usually needed for TimeMixerPP init model_config_dict['seq_len'] = config['data_settings']['seq_len'] # Get from data settings model_config_dict['pred_len'] = 0 # Classification specific model_config_dict['label_len'] = 0 # Classification specific model_config_dict['task_name'] = 'classification' # Ensure d_ff is calculated if needed if 'd_ff_multiple' in hpo_params: model_config_dict['d_ff'] = model_config_dict['d_model'] * hpo_params['d_ff_multiple'] # Ensure MLP params are present in the main config dict # (They should already be there from the merge above if defined in search_space or fixed_params) if 'mlp_hidden_dim' not in model_config_dict: # Provide a default or raise error if it's mandatory logging.warning("mlp_hidden_dim not found in combined config, check YAML search_space/fixed_params.") # model_config_dict['mlp_hidden_dim'] = 256 # Example default raise ValueError("mlp_hidden_dim is required in the model configuration.") if 'mlp_dropout' not in model_config_dict: logging.warning("mlp_dropout not found in combined config, check YAML search_space/fixed_params.") # model_config_dict['mlp_dropout'] = 0.1 # Example default raise ValueError("mlp_dropout is required in the model configuration.") # Add potentially missing default TimeMixerPP args if not in config/hpo # These should be added to the main model_config_dict model_config_dict.setdefault('d_layers', 1) model_config_dict.setdefault('embed', 'timeF') model_config_dict.setdefault('freq', 'h') model_config_dict.setdefault('activation', 'gelu') model_config_dict.setdefault('output_attention', False) model_config_dict.setdefault('decomp_method', 'moving_avg') model_config_dict.setdefault('use_norm', 1) model_config_dict.setdefault('revin', 1) model_config_dict.setdefault('affine', 0) model_config_dict.setdefault('subtract_last', 0) model_config_dict.setdefault('decomposition', 0) model_config_dict.setdefault('kernel_size', 25) model_config_dict.setdefault('individual', 0) try: # Pass the combined dictionary as 'model_config' model = TimeMixerPPClassifier( model_config=model_config_dict, # <-- 수정: 올바른 인수 이름 사용 num_classes_per_label=config['fixed_params']['num_classes_per_label'], label_names=config['fixed_params']['label_names'] # <-- 수정: **classifier_params 제거 ) model.to(device) except Exception as e: logging.error(f"Error initializing model: {e}", exc_info=True) # 로그에 사용된 전체 설정 딕셔너리 출력 logging.error(f"Model Config used for init: {model_config_dict}") raise # Fail the trial # --- 5. 
Optimizer, Scheduler, Scaler --- optimizer = optim.AdamW(model.parameters(), lr=hpo_params['lr'], weight_decay=hpo_params['weight_decay']) epochs_per_trial = config['fixed_params']['epochs_per_trial'] num_optimizer_steps_per_epoch = len(train_loader) // accumulation_steps if len(train_loader) % accumulation_steps != 0: num_optimizer_steps_per_epoch +=1 if num_optimizer_steps_per_epoch == 0: num_optimizer_steps_per_epoch = 1 num_training_steps = epochs_per_trial * num_optimizer_steps_per_epoch # Optional: Make scheduler type/params configurable via YAML scheduler = optim.lr_scheduler.OneCycleLR( optimizer, max_lr=hpo_params['lr'], total_steps=num_training_steps, pct_start=hpo_params.get('pct_start', 0.1) # Get pct_start from HPO params or default ) scaler = GradScaler(device=device, enabled=(device.type == 'cuda')) # --- 6. Training Loop with Pruning --- best_val_f1 = -1.0 try: for epoch in range(epochs_per_trial): epoch_num = epoch + 1 train_loss = train_epoch_timemixer( model, device, train_loader, criterion, optimizer, scheduler, scaler, config, current_epoch=epoch_num # Pass full config ) val_loss, val_f1, val_f1_tasks = evaluate_epoch_timemixer( model, device, val_loader, criterion, config, phase='val', current_epoch=epoch_num # Pass config ) current_lr = optimizer.param_groups[0]['lr'] logging.debug(f'T{trial.number} E{epoch_num}/{epochs_per_trial} | LR: {current_lr:.2e} | TrL: {train_loss:.4f} | Vl L: {val_loss:.4f} | Vl F1: {val_f1:.4f}') if val_f1 > best_val_f1: best_val_f1 = val_f1 # Optional: Save best model checkpoint per trial if needed # save_path = f"hpo_checkpoints/trial_{trial.number}_best_f1_{best_val_f1:.4f}.pth" # os.makedirs(os.path.dirname(save_path), exist_ok=True) # torch.save(model.state_dict(), save_path) trial.report(val_f1, epoch) if trial.should_prune(): logging.debug(f"Trial {trial.number} pruned at epoch {epoch_num}.") raise optuna.exceptions.TrialPruned() except optuna.exceptions.TrialPruned: return best_val_f1 # Return best score before pruning except RuntimeError as e: if "CUDA out of memory" in str(e): logging.error(f"Trial {trial.number} failed: CUDA OOM. Pruning.", exc_info=False) # Less verbose OOM log raise optuna.exceptions.TrialPruned("CUDA out of memory") else: logging.error(f"Trial {trial.number} failed with unexpected RuntimeError: {e}", exc_info=True) raise optuna.exceptions.TrialPruned(f"RuntimeError: {e}") # Prune on other runtime errors except Exception as e: logging.error(f"Trial {trial.number} failed with unexpected error: {e}", exc_info=True) raise optuna.exceptions.TrialPruned(f"Unexpected error: {e}") # Prune on other errors logging.info(f"Trial {trial.number} finished. Best Val F1: {best_val_f1:.4f}") return best_val_f1 # --- Main HPO Execution Block --- if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run TimeMixerPP HPO using Optuna and YAML config.") parser.add_argument("config_path", type=str, help="Path to the YAML configuration file.") parser.add_argument("--device", type=str, default=None, help="Override device (e.g., cuda:0, cpu). If None, determined automatically or by CUDA_VISIBLE_DEVICES.") cli_args = parser.parse_args() # --- 1. 
Load Configuration --- try: with open(cli_args.config_path, 'r') as f: config = yaml.safe_load(f) logging.info(f"Loaded configuration from: {cli_args.config_path}") except FileNotFoundError: logging.error(f"Configuration file not found at: {cli_args.config_path}") sys.exit(1) except yaml.YAMLError as e: logging.error(f"Error parsing YAML configuration file: {e}") sys.exit(1) except Exception as e: logging.error(f"An unexpected error occurred loading the config: {e}") sys.exit(1) # --- 2. Determine Device --- if cli_args.device: device_str = cli_args.device logging.info(f"Using device specified via command line: {device_str}") elif "CUDA_VISIBLE_DEVICES" in os.environ: # If CUDA_VISIBLE_DEVICES is set, usually to a single ID for parallel runs, use cuda:0 # Assumes the environment variable correctly isolates the GPU if os.environ["CUDA_VISIBLE_DEVICES"].strip(): device_str = "cuda:0" logging.info(f"Using cuda:0 based on CUDA_VISIBLE_DEVICES='{os.environ['CUDA_VISIBLE_DEVICES']}'") else: # If set to empty string, means no GPU should be visible device_str = "cpu" logging.info("CUDA_VISIBLE_DEVICES is empty, using CPU.") elif torch.cuda.is_available(): device_str = "cuda:0" # Default to first GPU if available and not overridden logging.info("CUDA available, default device: cuda:0") else: device_str = "cpu" logging.info("CUDA not available, using CPU.") # Validate device choice try: if "cuda" in device_str: if not torch.cuda.is_available(): logging.warning(f"Requested CUDA device {device_str} but CUDA not available. Falling back to CPU.") device = torch.device("cpu") else: # Check if the specific device index is valid try: gpu_index = int(device_str.split(':')[-1]) if gpu_index >= torch.cuda.device_count(): logging.warning(f"Requested GPU index {gpu_index} is invalid (found {torch.cuda.device_count()} GPUs). Using cuda:0.") device = torch.device("cuda:0") else: device = torch.device(device_str) except (ValueError, IndexError): logging.warning(f"Invalid CUDA device format: '{device_str}'. Using cuda:0.") device = torch.device("cuda:0") else: device = torch.device("cpu") logging.info(f"Final device selected for this run: {device}") except Exception as e: logging.error(f"Error setting up device '{device_str}': {e}. Exiting.") sys.exit(1) # --- 3. Load Data Once --- logging.info("--- Starting HPO Setup ---") logging.info("Loading or preprocessing data (once)...") try: _all_data, _all_labels, _actual_n_channels = load_or_preprocess_data(config) if not _all_data or _all_labels is None or _actual_n_channels is None: raise ValueError("Data loading returned invalid results.") PRELOADED_DATA = (_all_data, _all_labels, _actual_n_channels) logging.info(f"Data loaded successfully. Actual N_CHANNELS determined: {_actual_n_channels}") # Optionally update config if n_channels needs to be stored globally, but model init uses the value directly # config['fixed_params']['n_channels'] = _actual_n_channels except Exception as e: logging.error(f"Fatal error during data loading: {e}", exc_info=True) sys.exit(1) # --- 4. 
Configure Optuna Study --- hpo_cfg = config['hpo_settings'] study_name = hpo_cfg['study_name'] storage_name = hpo_cfg.get('storage') # None if not specified (in-memory) direction = hpo_cfg.get('direction', 'maximize') # Instantiate Pruner from config pruner_cfg = hpo_cfg.get('pruner') pruner = None if pruner_cfg and pruner_cfg.get('type'): try: pruner_class = getattr(optuna.pruners, pruner_cfg['type']) pruner_params = pruner_cfg.get('params', {}) pruner = pruner_class(**pruner_params) logging.info(f"Using pruner: {pruner_cfg['type']} with params {pruner_params}") except AttributeError: logging.error(f"Pruner type '{pruner_cfg['type']}' not found in optuna.pruners.") sys.exit(1) except Exception as e: logging.error(f"Error initializing pruner: {e}") sys.exit(1) else: logging.info("No pruner configured or pruner type missing.") logging.info(f"Creating/Loading Optuna study '{study_name}' (Storage: {storage_name or 'in-memory'})...") try: study = optuna.create_study( study_name=study_name, storage=storage_name, direction=direction, pruner=pruner, load_if_exists=True ) except Exception as e: logging.error(f"Error creating or loading Optuna study: {e}", exc_info=True) sys.exit(1) # --- 5. Run Optimization --- n_trials = hpo_cfg['n_trials'] n_jobs = hpo_cfg.get('n_jobs', 1) # Default to 1 job (sequential) timeout = hpo_cfg.get('timeout') # Optional timeout in seconds if n_jobs > 1 and device.type == 'cuda': logging.warning(f"n_jobs={n_jobs} requested with CUDA device. Optuna's default parallel execution might lead to GPU contention.") logging.warning("For multi-GPU parallelism, it's recommended to set n_jobs=1 in the config and launch this script multiple times,") logging.warning("each targeting a different GPU (e.g., using CUDA_VISIBLE_DEVICES=0, CUDA_VISIBLE_DEVICES=1, ...).") logging.warning("Ensure you are using a shared Optuna storage (e.g., SQLite or PostgreSQL) for parallel runs.") # Consider forcing n_jobs=1 or exiting if a single CUDA device is detected with n_jobs > 1 # if torch.cuda.device_count() == 1: # logging.info("Detected single GPU, forcing n_jobs=1 for stability.") # n_jobs = 1 logging.info(f"Starting Optuna optimization with {n_trials} trials...") logging.info(f" Parallel jobs (n_jobs): {n_jobs}") logging.info(f" Timeout: {timeout or 'Not set'}") logging.info(f" Device for this process: {device}") # Use functools.partial to pass fixed arguments (config, data, device) to the objective objective_with_args = functools.partial(objective, config=config, preloaded_data=PRELOADED_DATA, device=device) try: study.optimize( objective_with_args, n_trials=n_trials, n_jobs=n_jobs, # Pass n_jobs for potential parallelism timeout=timeout ) except Exception as e: logging.error(f"An error occurred during study optimization: {e}", exc_info=True) # The study might be partially completed. # --- 6. 
Report Results --- logging.info("\n--- HPO Finished ---") # Ensure study has trials before accessing attributes if study.trials: logging.info(f"Study statistics: ") logging.info(f" Number of finished trials: {len(study.trials)}") pruned_trials = len(study.get_trials(deepcopy=False, states=[optuna.trial.TrialState.PRUNED])) complete_trials = len(study.get_trials(deepcopy=False, states=[optuna.trial.TrialState.COMPLETE])) failed_trials = len(study.get_trials(deepcopy=False, states=[optuna.trial.TrialState.FAIL])) logging.info(f" Pruned trials: {pruned_trials}") logging.info(f" Complete trials: {complete_trials}") logging.info(f" Failed trials: {failed_trials}") try: best_trial = study.best_trial logging.info(f"Best trial:") logging.info(f" Number: {best_trial.number}") logging.info(f" Value (Best {direction} F1): {best_trial.value:.6f}") logging.info(f" Params (Suggested): ") for key, value in best_trial.params.items(): # Parameters suggested by Optuna logging.info(f" {key}: {value}") # Also log the combined HPO parameters used (including fixed ones like batch_size) if best_trial.user_attrs.get("hpo_params"): logging.info(f" User Attributes (Full HPO Params used):") for key, value in best_trial.user_attrs["hpo_params"].items(): logging.info(f" {key}: {value}") except ValueError: logging.warning("Could not determine the best trial (perhaps no trials completed successfully).") except Exception as e: logging.error(f"An error occurred retrieving best trial info: {e}") # Optional: Save study results dataframe try: results_df_path = f"{study_name}_results.csv" df = study.trials_dataframe() df.to_csv(results_df_path, index=False) logging.info(f"Study results saved to {results_df_path}") except Exception as e: logging.error(f"Failed to save study results dataframe: {e}") else: logging.warning("No trials were recorded in the study.") logging.info("HPO script finished.")

Thinking

Researching latest time series models

I’ll need to explore the current state of time series classification for multi-label, multi-class tasks in 2025. That means looking at new models like TSVit, RT-Transformer, and Perceiver IO, along with advancements related to retrieval-based transformers. I'll focus on finding efficient methods like TS-Mamba and state space models. I also plan to suggest improvements like enhanced hyperparameter search, PyTorch Lightning integration, and model upgrades such as hierarchical attention and multi-level temporal fusion. I’ll need sources to support these ideas.

Proposing improvements for multi-label classification

For the multi-label classification task, I’d recommend a few key changes. First, consider replacing TimeMixerPP with state-space models like TS-Mamba. It could also be valuable to incorporate foundation models or self-supervised methods such as TS2Vec for pretraining. Enhancing the data pipeline with jitter, scaling, and Mixup could improve performance too. Let's not forget advanced loss strategies, like hierarchical label smoothing, and using Optuna for optimized hyperparameter search. I’ll also suggest using strategies like early stopping, model architecture changes, and distributed training! Let’s gather relevant sources to back this up.

Thought for 26 seconds

Answer

The proposals below are architecture, training, and experiment-organization strategies for pushing the **"current TimeMixerPP + Optuna pipeline"** to 2025-era SOTA. Each step is described with the modules and references it needs, so apply them one at a time, starting with whatever fits your setup best.


1. Replace or merge the encoder into the "foundation-model" regime

| Direction | Core idea | Application hint |
| --- | --- | --- |
| (a) Full swap to TS-Mamba/SSM blocks | Mamba-based selective SSMs reach Transformer-level accuracy on long sequences in linear time. | Pull in the state-spaces/mamba implementation and use it as the encoder stack in place of TimeMixerPP_Model, or merge it into the hybrid in (b). (GitHub, arXiv) |
| (b) TimeMixerPP × Mamba hybrid | Handle the coarse scales with TimeMixerPP and the fine scales with Mamba blocks. | Put the two block types in an nn.ModuleList and sum them through skip connections. |
| (c) Fine-tune a pretrained foundation model | Load a 2025-released TS-Mamba, Mamba4Cast, or HuggingFace tsfm checkpoint, freeze the encoder weights, then unfreeze gradually. | Even on a single GPU, a few fine-tuning steps already yield a large gain. (arXiv, OpenReview) |

Code sketch

python
import torch.nn as nn
from mamba_ssm import Mamba  # pip install mamba-ssm; the package exports `Mamba`, not `MambaBlock`

class MambaEncoder(nn.Module):
    def __init__(self, d_model, depth):
        super().__init__()
        self.blocks = nn.ModuleList([Mamba(d_model=d_model) for _ in range(depth)])

    def forward(self, x):
        # x: (B, T, d_model); each Mamba block keeps the shape
        for blk in self.blocks:
            x = blk(x)
        return x

In TimeMixerPPClassifier, either substitute the module above for self.timemixer_encoder, or run the two encoders in parallel, concatenate their outputs, and unify the dimensions with projector = nn.Linear(sum_dim, d_model).
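A minimal sketch of that parallel-then-concat variant, assuming both branches return (B, T, D)-shaped features; the real TimeMixerPP forward returns a list of multi-scale tensors, so an adapter around it would still be needed, and HybridEncoder plus the dimension arguments are illustrative names:

python
import torch
import torch.nn as nn

class HybridEncoder(nn.Module):
    """Run two encoders in parallel and project the concatenated features back to d_model."""
    def __init__(self, encoder_a: nn.Module, encoder_b: nn.Module,
                 dim_a: int, dim_b: int, d_model: int):
        super().__init__()
        self.encoder_a = encoder_a
        self.encoder_b = encoder_b
        self.projector = nn.Linear(dim_a + dim_b, d_model)  # sum_dim -> d_model

    def forward(self, x):
        h_a = self.encoder_a(x)            # (B, T, dim_a)
        h_b = self.encoder_b(x)            # (B, T, dim_b)
        h = torch.cat([h_a, h_b], dim=-1)  # (B, T, dim_a + dim_b)
        return self.projector(h)           # (B, T, d_model)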


2. Upgrade multi-scale / multi-view fusion

  1. Frequency-domain view: as in the TSCMamba paper, extract CWT/FFT features separately and fuse them with channel-wise cross-attention. (arXiv)
  2. Hierarchical pyramid: downsample seq_len to {L, L/2, L/4, …}, feed each scale into its own encoder, and combine at the end with gated additive fusion (GatedFusion = σ(W₁·[h₀;h₁]) + (1-σ)…); see the sketch after this list.
  3. Label-dependency GNN: learn the positive/negative dependencies between labels with a GCN or graph-attention layer and inject them FiLM-style into the MLP-head inputs, so label correlations are learned (raises F1 especially when classes are imbalanced).
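A minimal sketch of the gated additive fusion in item 2, for two scale features h0 and h1 of shape (B, d_model); since the formula above is elided, the exact gating form used here (a sigmoid gate over the concatenation, then a convex combination of the two branches) is an assumption:

python
import torch
import torch.nn as nn

class GatedFusion(nn.Module):
    """Gated additive fusion of two scale-specific feature vectors h0, h1 of shape (B, d_model)."""
    def __init__(self, d_model: int):
        super().__init__()
        self.gate = nn.Linear(2 * d_model, d_model)

    def forward(self, h0: torch.Tensor, h1: torch.Tensor) -> torch.Tensor:
        g = torch.sigmoid(self.gate(torch.cat([h0, h1], dim=-1)))  # σ(W·[h0; h1])
        return g * h0 + (1.0 - g) * h1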

3. Loss & metric improvements

| Item | Implementation tip |
| --- | --- |
| Class-balanced focal loss (per task) | alpha = effective-num / Σ effective-num; gamma = 2.0 |
| Label-aware Mixup/CutMix | Mix only samples that share the same labels, so each task's positive/negative ratio is preserved. |
| Evaluation | Log subset accuracy and micro-F1 alongside the average Macro-F1; store them in a WandB table. |
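A hedged sketch of the per-task class-balanced focal loss from the first row, following the effective-number weighting with gamma = 2.0; beta = 0.999 and the normalization constant are assumptions, and one instance per label head would replace the shared nn.CrossEntropyLoss in train_hpo.py:

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ClassBalancedFocalLoss(nn.Module):
    """Focal loss with class-balanced (effective-number) alpha weights for one classification task."""
    def __init__(self, samples_per_class, beta: float = 0.999, gamma: float = 2.0):
        super().__init__()
        counts = torch.as_tensor(samples_per_class, dtype=torch.float32)
        effective_num = 1.0 - torch.pow(beta, counts)
        alpha = (1.0 - beta) / effective_num
        alpha = alpha / alpha.sum() * len(counts)  # normalize so weights average to 1
        self.register_buffer("alpha", alpha)
        self.gamma = gamma

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        ce = F.cross_entropy(logits, targets, reduction="none")  # per-sample -log p_t
        p_t = torch.exp(-ce)
        loss = self.alpha[targets] * (1.0 - p_t) ** self.gamma * ce
        return loss.mean()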

4. Optuna → multi-objective + dynamic search space

  • Optimize Macro-F1 and peak GPU memory (peak CUDA MB) jointly in a multi-objective study (optuna.create_study(directions=["maximize", "minimize"])), so deployable model size is explored along with accuracy; a sketch follows this list.
  • Dynamic trial pruning: every 3 epochs, prune if val_loss is more than 10% above the previous best.
  • Example search space
    • encoder_type ∈ {TimeMixerPP, Mamba, Hybrid}
    • d_model: [128-768] log-scale, depth: [4-24], kernel_size
    • augmentation ⊂ {jitter, scaling, time_warp} (categorical multi-choice)
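A sketch of the multi-objective study with a conditional (encoder_type-gated) search space, assuming the existing per-trial training loop is wrapped in a hypothetical train_and_evaluate(params) that returns the best validation Macro-F1; note that Optuna does not support pruners in multi-objective studies, so the pruning rule above would apply only to single-objective runs:

python
import optuna
import torch

def suggest_model_space(trial: optuna.trial.Trial) -> dict:
    params = {
        "encoder_type": trial.suggest_categorical("encoder_type", ["TimeMixerPP", "Mamba", "Hybrid"]),
        "d_model": trial.suggest_int("d_model", 128, 768, log=True),
        "depth": trial.suggest_int("depth", 4, 24),
    }
    if params["encoder_type"] != "Mamba":
        # kernel_size only matters when a TimeMixerPP branch is present
        params["kernel_size"] = trial.suggest_categorical("kernel_size", [13, 25, 49])
    return params

def train_and_evaluate(params: dict) -> float:
    # Hypothetical stand-in for the existing per-trial training loop in train_hpo.py.
    raise NotImplementedError

def objective(trial: optuna.trial.Trial):
    params = suggest_model_space(trial)
    torch.cuda.reset_peak_memory_stats()
    best_val_f1 = train_and_evaluate(params)
    peak_mb = torch.cuda.max_memory_allocated() / 2**20
    return best_val_f1, peak_mb  # (Macro-F1 to maximize, peak CUDA MB to minimize)

study = optuna.create_study(directions=["maximize", "minimize"])
# study.optimize(objective, n_trials=100)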

5. Data pipeline & infra

| Improvement | Details |
| --- | --- |
| Stratified GroupKFold | Keep grouping at the subject_id level so no subject leaks across the train/validation split. |
| TorchData / streaming | Lazy-load the Parquet files through an IterableDataset, reducing host memory and pin_memory pressure. |
| Mixed precision & gradient checkpointing | The scaler is already in use; applying torch.utils.checkpoint inside the encoder cuts VRAM by roughly 20-40%. |
| SWA + EMA | Run SWA updates over the last 20% of epochs, alongside an EMA (decay 0.999), for better generalization. |
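A minimal sketch for the Stratified GroupKFold row, replacing the plain train_test_split in the objective so that days from the same subject never sit on both sides of the split; it assumes a subject_ids array with one entry per processed day (which preprocess_data_timeseries would need to return) and stratifies on the first label column, as the current code does:

python
import numpy as np
from sklearn.model_selection import StratifiedGroupKFold

def grouped_split(all_labels: np.ndarray, subject_ids: np.ndarray,
                  n_splits: int = 5, seed: int = 42):
    """Yield (train_idx, val_idx) pairs with each subject_id confined to one side of the split."""
    sgkf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    y = all_labels[:, 0]  # stratify on the first task, matching the current train_test_split
    yield from sgkf.split(np.zeros(len(y)), y, groups=subject_ids)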

6. Modernize positional / time-series encoding

Switch to the relative (orthogonal-polynomial) + ΔTime combination recommended in recent surveys; it lifts Transformer-family performance by 1-2 points. (arXiv)

python
import torch
import torch.nn as nn

class LearnedRelPos(nn.Module):
    def __init__(self, L, d):
        super().__init__()
        self.rel = nn.Parameter(torch.randn(L, d))

    def forward(self, T):
        # return the first T learned position vectors, shape (T, d)
        return self.rel[:T]

Add x + pos[:T] + delta_t_embed at the encoder input.
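One way the delta_t_embed term could be built, as a minimal sketch assuming per-step timestamps are available as a (B, T) float tensor; projecting the raw time gaps through a linear layer is an assumption, since the answer only names the combination:

python
import torch
import torch.nn as nn

class DeltaTimeEmbedding(nn.Module):
    """Embed per-step time gaps Δt so irregular sampling is visible to the encoder."""
    def __init__(self, d_model: int):
        super().__init__()
        self.proj = nn.Linear(1, d_model)

    def forward(self, timestamps: torch.Tensor) -> torch.Tensor:  # timestamps: (B, T)
        delta_t = torch.diff(timestamps, dim=1, prepend=timestamps[:, :1])  # first step gets Δt = 0
        return self.proj(delta_t.unsqueeze(-1))  # (B, T, d_model)

# usage, broadcasting the (T, d) relative positions over the batch:
# enc_in = x + rel_pos(x.size(1)) + dt_embed(timestamps)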


Recommended order

  1. Swap the encoder (Mamba → Hybrid)
  2. Put the loss/metric and data-fold safety nets in place
  3. Expand the HPO search space → multi-objective
  4. Fusion and label GNN → follow-up experiments

If you proceed brick by brick, logging each small experiment as you go, you can expect the overall Macro-F1 to rise by roughly 5-10 points. If you want more code snippets for any particular section, just ask!
