```# 导入所需的库 import pandas as pd import numpy as np...

创建于:2025年6月4日

使用 O3 回答 Chat01

问题

text
import pandas as pd import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader, IterableDataset from sklearn.preprocessing import MinMaxScaler from sklearn.model_selection import train_test_split from sklearn.metrics import f1_score, confusion_matrix, classification_report import matplotlib.pyplot as plt import seaborn as sns import os import json import time import csv from tqdm import tqdm from fastapi import FastAPI, Request, HTTPException from fastapi.responses import HTMLResponse from fastapi.templating import Jinja2Templates from pydantic import BaseModel import uvicorn import logging import io import base64 from typing import List, Dict, Union, Tuple, Optional import gc import multiprocessing from functools import partial from concurrent.futures import ProcessPoolExecutor set_lstm_hidden_size=128 set_lstm_layers = 6 set_dropout=0.5 # 配置日志 logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler("fault_prediction.log"), logging.StreamHandler() ] ) logger = logging.getLogger(__name__) # 设置随机种子以确保可重现性 def set_seed(seed=42): """ 设置所有随机种子以确保结果可重现 参数: seed (int): 随机种子值 """ np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) # 如果使用多GPU torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False set_seed() # 设备配置 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') logger.info(f"使用设备: {device}") # 检查是否有多个GPU可用 if torch.cuda.device_count() > 1: logger.info(f"使用 {torch.cuda.device_count()} 个GPU进行训练") multi_gpu = True else: multi_gpu = False # 数据集定义 class FaultDataset(Dataset): """ 用于处理故障数据的自定义数据集类 """ def __init__(self, file_path: str = None, data: pd.DataFrame = None, normalize: bool = True, target_length: int = 150, alignment_method: str = 'linear_interpolation', augment: bool = False, scalers: dict = None): """ 初始化数据集 参数: file_path (str): CSV文件路径,可以为None如果直接提供dataframe data (pd.DataFrame): 直接提供的数据帧,可以为None如果提供file_path normalize (bool): 是否对数据进行归一化处理 target_length (int): 目标序列长度,用于对齐 alignment_method (str): 对齐方法,可选值为: 'linear_interpolation': 线性插值 'pad_zeros': 零填充 'dynamic_time_warping': 动态时间规整 'fourier_transform': 傅里叶变换重采样 augment (bool): 是否进行数据增强 scalers (dict): 已经拟合好的归一化器,用于测试或验证集 """ self.target_length = target_length self.normalize = normalize self.alignment_method = alignment_method self.augment = augment # 从文件加载数据或使用提供的数据帧 if file_path is not None: self.data = self.load_data(file_path) elif data is not None: self.data = data else: raise ValueError("必须提供file_path或data参数之一") # 分离特征和标签 self.features = self.data.drop(['data_id', 'measure_time', 'turnout_num', 'label'], axis=1) self.labels = self.data['label'].values # 存储原始曲线长度 self.original_lengths = { 'A相': [], 'B相': [], 'C相': [], 'P相': [] } # 处理和对齐曲线数据 self.processed_features = self.process_curves() # 数据归一化 if self.normalize: if scalers is None: self.scalers = self.fit_scalers() else: self.scalers = scalers # 应用归一化 self.normalized_features = self.apply_normalization() else: self.normalized_features = self.processed_features def load_data(self, file_path: str) -> pd.DataFrame: """ 从CSV文件加载数据,支持大文件分块处理以避免内存问题 参数: file_path (str): CSV文件路径 返回: pd.DataFrame: 加载的数据帧 """ # 检查文件是否存在 if not os.path.exists(file_path): raise FileNotFoundError(f"文件不存在: {file_path}") logger.info(f"从 {file_path} 加载数据...") try: # 对于大文件,使用分块读取以避免内存问题 chunk_size = 10000 # 每次读取的行数 chunks = [] for chunk in tqdm(pd.read_csv(file_path, chunksize=chunk_size), desc="加载数据块"): chunks.append(chunk) data = pd.concat(chunks, ignore_index=True) logger.info(f"数据加载完成。形状: {data.shape}") return data except Exception as e: logger.error(f"加载数据时出错: {str(e)}") raise def parse_curve(self, curve_str: str) -> List[float]: """ 解析曲线字符串为数值列表 参数: curve_str (str): 曲线字符串,格式如 "|0.0|0.15|0.6|..." 返回: List[float]: 解析后的数值列表 """ # 去除首尾的'|'并按'|'分割 values = curve_str.strip('|').split('|') # 转换为浮点数 return [float(val) for val in values] def process_curves(self) -> Dict[str, np.ndarray]: """ 处理所有曲线,解析并对齐到目标长度 返回: Dict[str, np.ndarray]: 处理后的特征字典,键为相名,值为对齐后的数组 """ processed = { 'A相': [], 'B相': [], 'C相': [], 'P相': [] } # 向量化处理曲线数据 for phase in processed.keys(): # 预处理字符串 curves = self.features[phase].str.strip('|').str.split('|') # 转换为数值型列表 curves = curves.apply(lambda x: [float(val) for val in x]) # 记录原始长度 self.original_lengths[phase] = curves.apply(len).values # 批量对齐 aligned_curves = [] for curve in tqdm(curves, desc=f"对齐{phase}曲线"): aligned_curves.append(self.align_curve(curve, self.target_length)) processed[phase] = np.array(aligned_curves) logger.info(f"曲线处理完成。每相形状: {processed['A相'].shape}") return processed def align_curve(self, curve: List[float], target_length: int) -> np.ndarray: """ 将曲线对齐到目标长度 参数: curve (List[float]): 原始曲线数据 target_length (int): 目标长度 返回: np.ndarray: 对齐后的曲线 """ if len(curve) == target_length: return np.array(curve) # 根据选择的方法进行对齐 if self.alignment_method == 'linear_interpolation': # 线性插值对齐 original_indices = np.linspace(0, len(curve) - 1, len(curve)) new_indices = np.linspace(0, len(curve) - 1, target_length) return np.interp(new_indices, original_indices, curve) elif self.alignment_method == 'pad_zeros': # 零填充对齐 if len(curve) > target_length: # 如果曲线长度大于目标长度,则截断 return np.array(curve[:target_length]) else: # 如果曲线长度小于目标长度,则填充零 padded = np.zeros(target_length) padded[:len(curve)] = curve return padded elif self.alignment_method == 'dynamic_time_warping': # 动态时间规整需要第三方库,这里只实现简化版本 scale_factor = target_length / len(curve) indices = np.round(np.arange(0, len(curve), 1/scale_factor)).astype(int) indices = indices[:target_length] return np.array([curve[min(i, len(curve)-1)] for i in indices]) elif self.alignment_method == 'fourier_transform': # 使用傅里叶变换进行重采样 try: from scipy import fftpack # 获取频域表示 curve_fft = fftpack.fft(curve) # 如果目标长度大于原始长度,补零 if target_length > len(curve): pad_size = target_length - len(curve) curve_fft_padded = np.pad(curve_fft, (0, pad_size), 'constant') # 如果目标长度小于原始长度,截断高频部分 else: curve_fft_padded = curve_fft[:target_length] # 逆变换回时域 return fftpack.ifft(curve_fft_padded).real except ImportError: logger.warning("scipy未安装,回退到线性插值") # 回退到线性插值 original_indices = np.linspace(0, len(curve) - 1, len(curve)) new_indices = np.linspace(0, len(curve) - 1, target_length) return np.interp(new_indices, original_indices, curve) elif self.alignment_method == 'jump_point_alignment': return self.align_by_jump_point(curve, target_length) else: # 默认使用线性插值 logger.warning(f"未知的对齐方法: {self.alignment_method},使用线性插值") original_indices = np.linspace(0, len(curve) - 1, len(curve)) new_indices = np.linspace(0, len(curve) - 1, target_length) return np.interp(new_indices, original_indices, curve) def fit_scalers(self) -> Dict[str, Dict[str, float]]: """ 为每个相拟合归一化器,优化内存使用 返回: Dict[str, Dict[str, float]]: 每个相的归一化参数 """ scalers = {} for phase in self.processed_features.keys(): data = self.processed_features[phase] min_val = np.min(data) max_val = np.max(data) scalers[phase] = {'min': min_val, 'max': max_val} logger.info("归一化器拟合完成") return scalers def apply_normalization(self) -> Dict[str, np.ndarray]: """ 应用归一化到处理后的特征,优化内存使用 返回: Dict[str, np.ndarray]: 归一化后的特征 """ normalized = {} for phase in self.processed_features.keys(): data = self.processed_features[phase] min_val = self.scalers[phase]['min'] max_val = self.scalers[phase]['max'] range_val = max_val - min_val if range_val == 0: normalized[phase] = np.zeros_like(data) else: normalized[phase] = (data - min_val) / range_val logger.info("特征归一化完成") return normalized def augment_data(self, data: Dict[str, np.ndarray], label: int) -> Tuple[Dict[str, np.ndarray], int]: """ 对数据进行增强 参数: data (Dict[str, np.ndarray]): 要增强的数据 label (int): 对应的标签 返回: Tuple[Dict[str, np.ndarray], int]: 增强后的数据和标签 """ # 只有在训练阶段才进行数据增强 if not self.augment: return data, label # 随机选择增强方法 augmentation_choice = np.random.rand() # 对数据的一个副本进行操作 augmented_data = {} for phase in data.keys(): augmented_data[phase] = data[phase].copy() if augmentation_choice < 0.25: # 方法1: 添加随机噪声 noise_level = np.random.uniform(0.01, 0.03) for phase in augmented_data.keys(): noise = np.random.normal(0, noise_level, size=augmented_data[phase].shape) augmented_data[phase] = augmented_data[phase] + noise elif augmentation_choice < 0.5: # 方法2: 时间滑动(左右移动) shift_amount = int(np.random.uniform(-10, 10)) for phase in augmented_data.keys(): if shift_amount > 0: augmented_data[phase] = np.pad(augmented_data[phase], (0, shift_amount), 'constant')[:self.target_length] else: augmented_data[phase] = np.pad(augmented_data[phase], (-shift_amount, 0), 'constant')[:-shift_amount if shift_amount != 0 else None] elif augmentation_choice < 0.75: # 方法3: 幅度缩放 scale_factor = np.random.uniform(0.9, 1.1) for phase in augmented_data.keys(): augmented_data[phase] = augmented_data[phase] * scale_factor else: # 方法4: 反转部分片段 if self.target_length > 10: # 随机选择一个片段进行反转 start_idx = np.random.randint(0, self.target_length - 10) end_idx = np.random.randint(start_idx + 5, min(start_idx + 10, self.target_length)) for phase in augmented_data.keys(): segment = augmented_data[phase][start_idx:end_idx].copy() augmented_data[phase][start_idx:end_idx] = segment[::-1] # 反转片段 return augmented_data, label def __len__(self): """返回数据集的大小""" return len(self.labels) def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], int]: """ 获取指定索引的样本 参数: idx (int): 样本索引 返回: Tuple[Dict[str, torch.Tensor], int]: 特征字典和标签 """ # 获取所有相的数据 sample = { 'A相': self.normalized_features['A相'][idx], 'B相': self.normalized_features['B相'][idx], 'C相': self.normalized_features['C相'][idx], 'P相': self.normalized_features['P相'][idx], } label = self.labels[idx] # 数据增强 if self.augment and np.random.rand() < 0.5: # 50%的概率进行增强 sample, label = self.augment_data(sample, label) # 转换为张量 tensor_sample = {phase: torch.FloatTensor(sample[phase]) for phase in sample.keys()} return tensor_sample, label def get_curve_length_stats(self) -> Dict[str, Dict[str, Union[float, int]]]: """ 获取曲线长度的统计信息 返回: Dict[str, Dict[str, Union[float, int]]]: 每个相的长度统计信息 """ stats = {} for phase in self.original_lengths.keys(): lengths = self.original_lengths[phase] stats[phase] = { 'min': np.min(lengths), 'max': np.max(lengths), 'mean': np.mean(lengths), 'median': np.median(lengths), 'std': np.std(lengths) } return stats def get_class_distribution(self) -> Dict[int, int]: """ 获取类别分布 返回: Dict[int, int]: 类别分布字典,键为类别,值为样本数 """ unique, counts = np.unique(self.labels, return_counts=True) return dict(zip(unique, counts)) @classmethod def create_train_val_test_sets(cls, file_path: str, train_ratio: float = 0.7, val_ratio: float = 0.15, test_ratio: float = 0.15, target_length: int = 150, alignment_method: str = 'linear_interpolation', stratify: bool = True, balance_method: Optional[str] = None, compute_class_weights: bool = False, random_state: int = 42) -> Tuple[Dataset, Dataset, Dataset, Optional[torch.Tensor]]: """ 创建训练、验证和测试集 参数: file_path (str): CSV文件路径 train_ratio (float): 训练集比例 val_ratio (float): 验证集比例 test_ratio (float): 测试集比例 target_length (int): 目标序列长度 alignment_method (str): 对齐方法 stratify (bool): 是否进行分层抽样 balance_method (str, 可选): 处理不平衡的方法,可选值: 'undersample': 欠采样 'oversample': 过采样 None: 不进行平衡处理 compute_class_weights (bool): 是否计算类别权重 random_state (int): 随机种子 返回: Tuple[Dataset, Dataset, Dataset, Optional[torch.Tensor]]: 训练集、验证集、测试集和类别权重 """ # 加载数据 temp_dataset = cls(file_path=file_path, normalize=False) # 获取标签分布 class_distribution = temp_dataset.get_class_distribution() logger.info(f"类别分布: {class_distribution}") # 创建数据帧,包含所有数据 df = temp_dataset.data # 分层划分数据集 if stratify: stratify_labels = df['label'] else: stratify_labels = None # 首先划分出测试集 train_val_df, test_df = train_test_split( df, test_size=test_ratio, random_state=random_state, stratify=stratify_labels ) # 然后从剩余数据中划分出验证集 if stratify: stratify_labels = train_val_df['label'] else: stratify_labels = None train_df, val_df = train_test_split( train_val_df, test_size=val_ratio/(train_ratio+val_ratio), random_state=random_state, stratify=stratify_labels ) logger.info(f"划分完成。训练集: {len(train_df)},验证集: {len(val_df)},测试集: {len(test_df)}") # 处理不平衡数据 if balance_method == 'undersample': # 欠采样:减少多数类的样本 train_df = cls._undersample(train_df) logger.info(f"欠采样后的训练集大小: {len(train_df)}") elif balance_method == 'oversample': # 过采样:增加少数类的样本 train_df = cls._oversample(train_df) logger.info(f"过采样后的训练集大小: {len(train_df)}") # 创建完整的数据集 # 首先创建训练集并拟合归一化器 train_dataset = cls( data=train_df, normalize=True, target_length=target_length, alignment_method=alignment_method, augment=True ) # 然后使用训练集的归一化器创建验证集和测试集 val_dataset = cls( data=val_df, normalize=True, target_length=target_length, alignment_method=alignment_method, augment=False, scalers=train_dataset.scalers ) test_dataset = cls( data=test_df, normalize=True, target_length=target_length, alignment_method=alignment_method, augment=False, scalers=train_dataset.scalers ) # 计算类别权重 class_weights = None if compute_class_weights: class_distribution = train_dataset.get_class_distribution() num_samples = len(train_dataset) num_classes = len(class_distribution) class_weights = torch.zeros(num_classes) for class_idx, class_count in class_distribution.items(): class_weights[class_idx] = num_samples / (num_classes * class_count) logger.info(f"类别权重: {class_weights}") return train_dataset, val_dataset, test_dataset, class_weights @staticmethod def _undersample(df: pd.DataFrame) -> pd.DataFrame: """ 对数据帧进行欠采样以平衡类别 参数: df (pd.DataFrame): 要处理的数据帧 返回: pd.DataFrame: 欠采样后的数据帧 """ class_counts = df['label'].value_counts() min_count = class_counts.min() # 对每个类别进行欠采样 balanced_dfs = [] for class_label in class_counts.index: class_df = df[df['label'] == class_label] # 如果类别样本数大于最小数量,随机采样 if len(class_df) > min_count: sampled_df = class_df.sample(min_count, random_state=42) balanced_dfs.append(sampled_df) else: balanced_dfs.append(class_df) # 合并所有采样后的数据 return pd.concat(balanced_dfs, ignore_index=True) @staticmethod def _oversample(df: pd.DataFrame) -> pd.DataFrame: """ 对数据帧进行过采样以平衡类别 参数: df (pd.DataFrame): 要处理的数据帧 返回: pd.DataFrame: 过采样后的数据帧 """ class_counts = df['label'].value_counts() max_count = class_counts.max() # 对每个类别进行过采样 balanced_dfs = [] for class_label in class_counts.index: class_df = df[df['label'] == class_label] # 如果类别样本数小于最大数量,随机重复采样 if len(class_df) < max_count: # 计算需要重复的次数 n_samples = max_count - len(class_df) # 使用替换采样 extra_samples = class_df.sample(n_samples, replace=True, random_state=42) balanced_dfs.append(pd.concat([class_df, extra_samples], ignore_index=True)) else: balanced_dfs.append(class_df) # 合并所有采样后的数据 return pd.concat(balanced_dfs, ignore_index=True) def visualization_sample(self, idx: int, save_path: Optional[str] = None) -> None: """ 可视化指定索引的样本 参数: idx (int): 样本索引 save_path (str, 可选): 保存路径,不提供则显示图像 """ sample, label = self[idx] plt.figure(figsize=(12, 8)) # 创建一个2x2的子图布局 phases = ['A相', 'B相', 'C相', 'P相'] for i, phase in enumerate(phases): plt.subplot(2, 2, i+1) plt.plot(sample[phase].numpy()) plt.title(f"{phase} 曲线") plt.xlabel("时间步") plt.ylabel("值") plt.grid(True) plt.suptitle(f"样本 {idx}, 标签: {label}") plt.tight_layout() if save_path: plt.savefig(save_path) plt.close() else: plt.show() def find_jump_point(self, curve: np.ndarray, threshold: float = 0.5) -> int: """ 寻找曲线中的跳变点 参数: curve: 曲线数据 threshold: 跳变阈值(百分比变化) 返回: int: 跳变点索引 """ # 计算差分 diffs = np.abs(np.diff(curve)) # 计算均值和标准差 mean_diff = np.mean(diffs) std_diff = np.std(diffs) # 设定跳变阈值(可调整) jump_threshold = mean_diff + threshold * std_diff # 找出超过阈值的点 jump_indices = np.where(diffs > jump_threshold)[0] if len(jump_indices) > 0: # 返回最显著的跳变点 return jump_indices[np.argmax(diffs[jump_indices])] + 1 else: # 没有显著跳变点,返回中点 return len(curve) // 2 def align_by_jump_point(self, curve: List[float], target_length: int) -> np.ndarray: """ 以跳变点为中心对齐曲线 参数: curve: 原始曲线 target_length: 目标长度 返回: 对齐后的曲线 """ curve_array = np.array(curve) # 找到跳变点 jump_idx = self.find_jump_point(curve_array) # 计算对齐后跳变点位置(例如,放在30%处) target_jump_idx = int(target_length * 0.3) # 计算需要的填充 left_pad = max(0, target_jump_idx - jump_idx) right_pad = max(0, (target_length - left_pad - len(curve))) # 如果原始曲线太长,需要截断 if len(curve) + left_pad + right_pad > target_length: # 计算截断量 trim_left = max(0, jump_idx - target_jump_idx) trim_right = max(0, len(curve) - trim_left - (target_length - left_pad - right_pad)) trimmed_curve = curve_array[trim_left:len(curve)-trim_right] else: trimmed_curve = curve_array # 填充 padded = np.zeros(target_length) padded[left_pad:left_pad+len(trimmed_curve)] = trimmed_curve return padded class StreamingFaultDataset(IterableDataset): """ 用于流式处理大数据集的自定义数据集类,避免一次性加载所有数据到内存 """ def __init__(self, file_path: str, assigned_indices_list: List[int], batch_size: int = 32, target_length: int = 150, alignment_method: str = 'jump_point_alignment', is_train: bool = True, scalers: dict = None, buffer_size: int = 1000): """ 初始化流式数据集 参数: file_path (str): CSV文件路径 batch_size (int): 批次大小 target_length (int): 目标序列长度 alignment_method (str): 对齐方法 is_train (bool): 是否为训练集 scalers (dict): 归一化参数 buffer_size (int): 缓冲区大小,用于加载和处理数据 train_ratio (float): 训练集比例 val_ratio (float): 验证集比例 test_ratio (float): 测试集比例 data_split (str): 数据集划分,'train', 'val', 或 'test' """ self.file_path = file_path self.assigned_indices = set(assigned_indices_list) self.num_samples_for_this_split = len(self.assigned_indices) self.batch_size = batch_size self.target_length = target_length self.alignment_method = alignment_method self.is_train = is_train self.scalers = scalers self.buffer_size = buffer_size logger.info(f"StreamingFaultDataset 初始化完成。将处理 {self.num_samples_for_this_split} 个样本。") def _find_jump_point(self, curve: np.ndarray, threshold: float = 0.5) -> int: """寻找曲线中的跳变点""" diffs = np.abs(np.diff(curve)) if len(diffs) == 0: return 0 mean_diff = np.mean(diffs) std_diff = np.std(diffs) jump_threshold = mean_diff + threshold * std_diff jump_indices = np.where(diffs > jump_threshold)[0] if len(jump_indices) > 0: return jump_indices[np.argmax(diffs[jump_indices])] + 1 else: return len(curve) // 2 def _align_by_jump_point(self, curve: List[float], target_length: int) -> np.ndarray: """以跳变点为中心对齐曲线""" curve_array = np.array(curve) jump_idx = self._find_jump_point(curve_array) target_jump_idx = int(target_length * 0.3) left_pad = max(0, target_jump_idx - jump_idx) right_pad = max(0, (target_length - left_pad - len(curve))) if len(curve) + left_pad + right_pad > target_length: trim_left = max(0, jump_idx - target_jump_idx) trim_right = max(0, len(curve) - trim_left - (target_length - left_pad - right_pad)) trimmed_curve = curve_array[trim_left:len(curve)-trim_right] else: trimmed_curve = curve_array padded = np.zeros(target_length) padded[left_pad:left_pad+len(trimmed_curve)] = trimmed_curve return padded def __iter__(self): """迭代器方法,按照数据集划分逐行读取文件并生成批次""" # 用于记录当前处理的行号 current_line = 0 # 缓冲区 buffer_features = { 'A相': [], 'B相': [], 'C相': [], 'P相': [] } buffer_labels = [] # 打开文件 with open(self.file_path, 'r', encoding='utf-8') as f: # 跳过标题行 next(f) # 逐行读取 for line_idx, line in enumerate(f): if line_idx not in self.assigned_indices: continue # 如果当前行不属于这个数据集划分,则跳过 try: # 解析行 row = line.strip().split(',') label = int(row[3]) # 解析和对齐四个相的曲线 a_phase = self._parse_and_align(row[4], 'A相') b_phase = self._parse_and_align(row[5], 'B相') c_phase = self._parse_and_align(row[6], 'C相') p_phase = self._parse_and_align(row[7], 'P相') # 添加到缓冲区 buffer_features['A相'].append(a_phase) buffer_features['B相'].append(b_phase) buffer_features['C相'].append(c_phase) buffer_features['P相'].append(p_phase) buffer_labels.append(label) # 当缓冲区满时,生成批次 if len(buffer_labels) >= self.batch_size: # 转换为张量 tensor_features = { phase: torch.FloatTensor(buffer_features[phase][:self.batch_size]) for phase in buffer_features } tensor_labels = torch.LongTensor(buffer_labels[:self.batch_size]) # 清空缓冲区,保留剩余样本 for phase in buffer_features: buffer_features[phase] = buffer_features[phase][self.batch_size:] buffer_labels = buffer_labels[self.batch_size:] # 如果是训练集且需要数据增强 if self.is_train and np.random.rand() < 0.3: # 30%概率执行增强 tensor_features = self._augment_batch(tensor_features) yield tensor_features, tensor_labels except Exception as e: logger.warning(f"跳过行 {line_idx}: {e}") # 处理剩余数据 if len(buffer_labels) > 0: # 修改:确保只要缓冲区有数据就处理 tensor_features = { phase: torch.FloatTensor(buffer_features[phase]) # 使用所有剩余数据 for phase in buffer_features } tensor_labels = torch.LongTensor(buffer_labels) # 使用所有剩余数据 if self.is_train and np.random.rand() < 0.3: tensor_features = self._augment_batch(tensor_features) yield tensor_features, tensor_labels def _parse_and_align(self, curve_str: str, phase: str) -> np.ndarray: """解析曲线字符串并对齐到目标长度""" # 解析曲线 values = curve_str.strip('|').split('|') curve = [float(val) for val in values] # 对齐曲线 if self.alignment_method == 'jump_point_alignment': aligned = self._align_by_jump_point(curve, self.target_length) else: # 默认使用线性插值 original_indices = np.linspace(0, len(curve) - 1, len(curve)) new_indices = np.linspace(0, len(curve) - 1, self.target_length) aligned = np.interp(new_indices, original_indices, curve) # 归一化 if self.scalers is not None: min_val = self.scalers[phase]['min'] max_val = self.scalers[phase]['max'] range_val = max_val - min_val if range_val == 0: normalized = np.zeros_like(aligned) else: normalized = (aligned - min_val) / range_val return normalized return aligned def _augment_batch(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """对批次数据进行增强""" augmentation_choice = np.random.rand() if augmentation_choice < 0.25: # 添加随机噪声 noise_level = np.random.uniform(0.01, 0.03) for phase in features: noise = torch.randn_like(features[phase]) * noise_level features[phase] = features[phase] + noise elif augmentation_choice < 0.5: # 时间滑动 shift_amount = int(np.random.uniform(-5, 5)) for phase in features: if shift_amount > 0: features[phase] = torch.cat([ torch.zeros(features[phase].size(0), shift_amount, device=features[phase].device), features[phase][:, :-shift_amount] ], dim=1) elif shift_amount < 0: features[phase] = torch.cat([ features[phase][:, -shift_amount:], torch.zeros(features[phase].size(0), -shift_amount, device=features[phase].device) ], dim=1) elif augmentation_choice < 0.75: # 幅度缩放 scale_factor = np.random.uniform(0.9, 1.1) for phase in features: features[phase] = features[phase] * scale_factor else: # 反转部分片段 if self.target_length > 10: start_idx = np.random.randint(0, self.target_length - 10) end_idx = np.random.randint(start_idx + 5, min(start_idx + 10, self.target_length)) for phase in features: segment = features[phase].clone()[:, start_idx:end_idx] features[phase][:, start_idx:end_idx] = torch.flip(segment, [1]) return features # 定义数据集的collate_fn,用于批处理 def collate_fn(batch): """ 自定义collate函数,处理不同长度的序列 参数: batch: 批次数据 返回: features_batch: 批次特征 labels_batch: 批次标签 """ # 分离特征和标签 features = [item[0] for item in batch] labels = [item[1] for item in batch] # 处理每个相的数据 features_batch = { 'A相': torch.stack([sample['A相'] for sample in features]), 'B相': torch.stack([sample['B相'] for sample in features]), 'C相': torch.stack([sample['C相'] for sample in features]), 'P相': torch.stack([sample['P相'] for sample in features]) } # 转换标签为张量 labels_batch = torch.LongTensor(labels) return features_batch, labels_batch # 模型定义 class AttentionLayer(nn.Module): """ 注意力层,用于关注序列中的重要部分 """ def __init__(self, hidden_size: int): """ 初始化注意力层 参数: hidden_size (int): 隐藏层大小 """ super(AttentionLayer, self).__init__() # 注意力计算所需的线性层 self.attn = nn.Linear(hidden_size * 2, hidden_size) self.v = nn.Parameter(torch.rand(hidden_size)) # 初始化参数 stdv = 1. / (hidden_size ** 0.5) self.v.data.uniform_(-stdv, stdv) def forward(self, hidden: torch.Tensor, encoder_outputs: torch.Tensor) -> torch.Tensor: """ 前向传播 参数: hidden (torch.Tensor): 隐藏状态 [batch_size, 1, hidden_size] encoder_outputs (torch.Tensor): 编码器输出 [batch_size, seq_len, hidden_size] 返回: torch.Tensor: 注意力加权后的上下文向量 [batch_size, 1, seq_len] """ batch_size = encoder_outputs.size(0) seq_len = encoder_outputs.size(1) # 将hidden复制seq_len次,使其与encoder_outputs的序列长度匹配 # 从 [batch_size, 1, hidden_size] 变为 [batch_size, seq_len, hidden_size] hidden_expanded = hidden.repeat(1, seq_len, 1) # 拼接hidden和encoder_outputs # [batch_size, seq_len, hidden_size*2] combined = torch.cat([hidden_expanded, encoder_outputs], dim=2) # 计算注意力能量 # [batch_size, seq_len, hidden_size] energy = torch.tanh(self.attn(combined)) # 计算注意力权重 # [batch_size, seq_len] attn_energies = torch.sum(self.v * energy, dim=2) # 归一化注意力权重 # [batch_size, 1, seq_len] return F.softmax(attn_energies, dim=1).unsqueeze(1) def score(self, hidden: torch.Tensor, encoder_outputs: torch.Tensor) -> torch.Tensor: """ 计算注意力得分 - 此方法已不再使用,保留为了兼容性 参数: hidden (torch.Tensor): 隐藏状态 encoder_outputs (torch.Tensor): 编码器输出 返回: torch.Tensor: 注意力得分 """ return self.forward(hidden, encoder_outputs).squeeze(1) class ResidualBlock(nn.Module): """ 残差块,用于构建深层CNN """ def __init__(self, in_channels: int, out_channels: int): """ 初始化残差块 参数: in_channels (int): 输入通道数 out_channels (int): 输出通道数 """ super(ResidualBlock, self).__init__() # 第一个卷积层 self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1) self.bn1 = nn.BatchNorm1d(out_channels) # 第二个卷积层 self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1) self.bn2 = nn.BatchNorm1d(out_channels) # 如果输入和输出通道数不同,需要一个1x1卷积进行调整 self.shortcut = nn.Sequential() if in_channels != out_channels: self.shortcut = nn.Sequential( nn.Conv1d(in_channels, out_channels, kernel_size=1), nn.BatchNorm1d(out_channels) ) def forward(self, x: torch.Tensor) -> torch.Tensor: """ 前向传播 参数: x (torch.Tensor): 输入张量 返回: torch.Tensor: 残差块的输出 """ # 保存原始输入用于残差连接 residual = self.shortcut(x) # 卷积路径 out = F.relu(self.bn1(self.conv1(x))) out = self.bn2(self.conv2(out)) # 添加残差连接并应用ReLU out += residual out = F.relu(out) return out class LSTMCNN(nn.Module): """ 结合CNN和LSTM的模型,用于处理时序数据 """ def __init__(self, input_channels: int = 4, seq_length: int = 150, num_classes: int = 9, cnn_out_channels: int = 128, lstm_hidden_size: int = 128, lstm_layers: int = 2, dropout: float = 0.5, bidirectional: bool = True): """ 初始化LSTMCNN模型 参数: input_channels (int): 输入通道数,对应4个相 seq_length (int): 序列长度 num_classes (int): 分类类别数 cnn_out_channels (int): CNN输出通道数 lstm_hidden_size (int): LSTM隐藏层大小 lstm_layers (int): LSTM层数 dropout (float): Dropout比率 bidirectional (bool): LSTM是否双向 """ super(LSTMCNN, self).__init__() self.seq_length = seq_length self.hidden_size = lstm_hidden_size self.num_layers = lstm_layers self.bidirectional = bidirectional self.num_directions = 2 if bidirectional else 1 # CNN特征提取器 - 每个相单独处理 self.cnn_a = nn.Sequential( ResidualBlock(1, 32), nn.MaxPool1d(2), ResidualBlock(32, 64), nn.MaxPool1d(2), ResidualBlock(64, cnn_out_channels), nn.MaxPool1d(2) ) self.cnn_b = nn.Sequential( ResidualBlock(1, 32), nn.MaxPool1d(2), ResidualBlock(32, 64), nn.MaxPool1d(2), ResidualBlock(64, cnn_out_channels), nn.MaxPool1d(2) ) self.cnn_c = nn.Sequential( ResidualBlock(1, 32), nn.MaxPool1d(2), ResidualBlock(32, 64), nn.MaxPool1d(2), ResidualBlock(64, cnn_out_channels), nn.MaxPool1d(2) ) self.cnn_p = nn.Sequential( ResidualBlock(1, 32), nn.MaxPool1d(2), ResidualBlock(32, 64), nn.MaxPool1d(2), ResidualBlock(64, cnn_out_channels), nn.MaxPool1d(2) ) # 计算CNN后的序列长度 cnn_out_length = seq_length // 8 # 三次MaxPool1d(2)后长度减为原来的1/8 # LSTM层 self.lstm = nn.LSTM( input_size=cnn_out_channels * input_channels, # 4个相的CNN特征拼接 hidden_size=lstm_hidden_size, num_layers=lstm_layers, batch_first=True, dropout=dropout if lstm_layers > 1 else 0, bidirectional=bidirectional ) # 注意力层 self.attention = AttentionLayer(lstm_hidden_size * self.num_directions) # 分类器 self.classifier = nn.Sequential( nn.Linear(lstm_hidden_size * self.num_directions, 256), nn.ReLU(), nn.Dropout(dropout), nn.Linear(256, num_classes) ) def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor: """ 前向传播 参数: x (Dict[str, torch.Tensor]): 输入特征字典,包含四个相的数据 返回: torch.Tensor: 模型输出(类别预测) """ batch_size = x['A相'].size(0) # 每个相的数据调整为CNN输入形状 [B, 1, L] x_a = x['A相'].unsqueeze(1) # [B, 1, L] x_b = x['B相'].unsqueeze(1) # [B, 1, L] x_c = x['C相'].unsqueeze(1) # [B, 1, L] x_p = x['P相'].unsqueeze(1) # [B, 1, L] # 通过CNN提取特征 cnn_out_a = self.cnn_a(x_a) # [B, C, L//8] cnn_out_b = self.cnn_b(x_b) # [B, C, L//8] cnn_out_c = self.cnn_c(x_c) # [B, C, L//8] cnn_out_p = self.cnn_p(x_p) # [B, C, L//8] # 调整形状以便于拼接,[B, C, L//8] -> [B, L//8, C] cnn_out_a = cnn_out_a.transpose(1, 2) cnn_out_b = cnn_out_b.transpose(1, 2) cnn_out_c = cnn_out_c.transpose(1, 2) cnn_out_p = cnn_out_p.transpose(1, 2) # 拼接四个相的特征,在特征维度上合并 cnn_out_combined = torch.cat([cnn_out_a, cnn_out_b, cnn_out_c, cnn_out_p], dim=2) # 通过LSTM处理时序特征 lstm_out, (hidden, _) = self.lstm(cnn_out_combined) # 使用注意力机制处理LSTM输出 if self.bidirectional: # 双向LSTM,需要拼接最后一层的前向和后向 hidden = torch.cat([hidden[-2], hidden[-1]], dim=1).unsqueeze(1) else: # 单向LSTM,只使用最后一层 hidden = hidden[-1].unsqueeze(1) attn_weights = self.attention(hidden, lstm_out) context = torch.bmm(attn_weights, lstm_out) context = context.squeeze(1) # 全连接分类 output = self.classifier(context) return output # 训练和评估函数 def train_model(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, num_epochs: int = 100, learning_rate: float = 0.001, weight_decay: float = 0.01, patience: int = 10, checkpoint_path: str = 'best_model.pth', class_weights: Optional[torch.Tensor] = None, device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')) -> Dict[str, List]: """ 训练模型 参数: model (nn.Module): 要训练的模型 train_loader (DataLoader): 训练数据加载器 val_loader (DataLoader): 验证数据加载器 num_epochs (int): 训练轮数 learning_rate (float): 学习率 weight_decay (float): 权重衰减系数 patience (int): 早停耐心值 checkpoint_path (str): 模型检查点保存路径 class_weights (torch.Tensor, 可选): 类别权重 device (torch.device): 训练设备 返回: Dict[str, List]: 包含训练历史的字典 """ logger.info(f"开始训练,设备: {device}") # 如果有多GPU,使用DataParallel if torch.cuda.device_count() > 1 and isinstance(device, torch.device) and device.type == 'cuda': logger.info(f"使用 {torch.cuda.device_count()} 个GPU进行训练") model = nn.DataParallel(model) model = model.to(device) # 尝试加载检查点 start_epoch = 0 best_val_f1 = 0.0 early_stop_counter = 0 if os.path.exists(checkpoint_path): try: logger.info(f"从 {checkpoint_path} 加载检查点...") checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) if isinstance(model, nn.DataParallel): model.module.load_state_dict(checkpoint['model_state_dict']) else: model.load_state_dict(checkpoint['model_state_dict']) optimizer_state_dict = checkpoint.get('optimizer_state_dict', None) scheduler_state_dict = checkpoint.get('scheduler_state_dict', None) start_epoch = checkpoint.get('epoch', -1) + 1 best_val_f1 = checkpoint.get('best_val_f1', 0.0) early_stop_counter = checkpoint.get('early_stop_counter', 0) logger.info(f"检查点加载成功。从 Epoch {start_epoch+1} 继续训练。") logger.info(f"恢复的最佳验证F1: {best_val_f1:.4f}, Early stop counter: {early_stop_counter}") except Exception as e: logger.error(f"加载检查点失败: {e}。将从头开始训练。") start_epoch = 0 best_val_f1 = 0.0 early_stop_counter = 0 else: logger.info("未找到检查点,将从头开始训练。") # 定义损失函数和优化器 if class_weights is not None: class_weights = class_weights.to(device) criterion = nn.CrossEntropyLoss(weight=class_weights) logger.info("使用带权重的交叉熵损失") else: criterion = nn.CrossEntropyLoss() logger.info("使用标准交叉熵损失") optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5) # 如果加载了检查点,恢复优化器和调度器状态 if 'optimizer_state_dict' in locals() and optimizer_state_dict is not None: optimizer.load_state_dict(optimizer_state_dict) logger.info("优化器状态已恢复") if 'scheduler_state_dict' in locals() and scheduler_state_dict is not None: scheduler.load_state_dict(scheduler_state_dict) logger.info("学习率调度器状态已恢复") # 训练历史 history = { 'train_loss': [], 'train_acc': [], 'train_f1': [], 'val_loss': [], 'val_acc': [], 'val_f1': [] } # 早停机制 # best_val_f1 已经从检查点恢复或初始化为0.0 # early_stop_counter 已经从检查点恢复或初始化为0 # 开始训练循环 for epoch in range(start_epoch, num_epochs): model.train() train_loss = 0.0 train_correct = 0 train_total = 0 train_pred_all = [] train_true_all = [] batch_count = 0 # 计数处理的批次数 # 训练阶段 progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]") for batch_idx, (features, targets) in enumerate(progress_bar): batch_count += 1 # 更新批次计数 # 将数据移动到设备 for phase in features: features[phase] = features[phase].to(device) targets = targets.to(device) # 清除梯度 optimizer.zero_grad() # 前向传播 outputs = model(features) loss = criterion(outputs, targets) # 反向传播和优化 loss.backward() optimizer.step() # 统计 train_loss += loss.item() _, predicted = outputs.max(1) train_correct += predicted.eq(targets).sum().item() train_total += targets.size(0) # 收集预测和真实标签用于计算F1 train_pred_all.extend(predicted.cpu().numpy()) train_true_all.extend(targets.cpu().numpy()) # 更新进度条 - 使用当前批次平均值 current_train_loss_avg = train_loss / batch_count if batch_count > 0 else 0.0 current_train_acc_avg = train_correct / train_total if train_total > 0 else 0.0 progress_bar.set_postfix(loss=current_train_loss_avg, acc=current_train_acc_avg) # 计算训练F1分数 train_f1 = f1_score(train_true_all, train_pred_all, average='macro', zero_division=0) # 验证阶段 model.eval() val_loss = 0.0 val_correct = 0 val_total = 0 val_pred_all = [] val_true_all = [] val_batch_count = 0 # 验证集批次计数 with torch.no_grad(): progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Validation]") for batch_idx, (features, targets) in enumerate(progress_bar): val_batch_count += 1 # 更新验证集批次计数 # 将数据移动到设备 for phase in features: features[phase] = features[phase].to(device) targets = targets.to(device) # 前向传播 outputs = model(features) loss = criterion(outputs, targets) # 统计 val_loss += loss.item() _, predicted = outputs.max(1) val_correct += predicted.eq(targets).sum().item() val_total += targets.size(0) # 收集预测和真实标签 val_pred_all.extend(predicted.cpu().numpy()) val_true_all.extend(targets.cpu().numpy()) # 更新进度条 - 使用当前批次平均值 current_val_loss_avg = val_loss / val_batch_count if val_batch_count > 0 else 0.0 current_val_acc_avg = val_correct / val_total if val_total > 0 else 0.0 progress_bar.set_postfix(loss=current_val_loss_avg, acc=current_val_acc_avg) # 计算验证F1分数 val_f1 = f1_score(val_true_all, val_pred_all, average='macro', zero_division=0) # 更新学习率 scheduler.step(val_f1) # 保存历史 - 使用批次计数而不是len(dataloader) # 添加除零保护 avg_train_loss = train_loss / batch_count if batch_count > 0 else 0.0 avg_val_loss = val_loss / val_batch_count if val_batch_count > 0 else 0.0 avg_train_acc = train_correct / train_total if train_total > 0 else 0.0 avg_val_acc = val_correct / val_total if val_total > 0 else 0.0 if val_batch_count == 0: logger.warning(f"Epoch {epoch+1}: 验证集为空或未产生批次,验证损失和准确率将为0。") history['train_loss'].append(avg_train_loss) history['train_acc'].append(avg_train_acc) history['train_f1'].append(train_f1) history['val_loss'].append(avg_val_loss) history['val_acc'].append(avg_val_acc) history['val_f1'].append(val_f1) # 计算每个类别的F1分数 train_f1_per_class = f1_score(train_true_all, train_pred_all, average=None, zero_division=0) val_f1_per_class = f1_score(val_true_all, val_pred_all, average=None, zero_division=0) # 输出本轮结果 logger.info(f"Epoch {epoch+1}/{num_epochs} - " f"Train Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_acc:.4f}, Train F1: {train_f1:.4f}, " f"Val Loss: {avg_val_loss:.4f}, Val Acc: {avg_val_acc:.4f}, Val F1: {val_f1:.4f}") # 输出每个类别的F1分数 logger.info(f"Train F1 per class: {', '.join([f'{f:.4f}' for f in train_f1_per_class])}") logger.info(f"Val F1 per class: {', '.join([f'{f:.4f}' for f in val_f1_per_class])}") # 检查是否是最佳模型 if val_f1 > best_val_f1: best_val_f1 = val_f1 # 保存模型和训练状态 checkpoint = { 'epoch': epoch, 'model_state_dict': model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), 'best_val_f1': best_val_f1, 'early_stop_counter': early_stop_counter, 'class_weights': class_weights, } torch.save(checkpoint, checkpoint_path) logger.info(f"检查点已保存到 {checkpoint_path},Epoch: {epoch+1}, 验证F1: {val_f1:.4f}") early_stop_counter = 0 else: early_stop_counter += 1 logger.info(f"验证F1未提升,当前计数: {early_stop_counter}/{patience}") # 早停检查 if early_stop_counter >= patience: logger.info(f"早停触发,停止训练。最佳验证F1: {best_val_f1:.4f}") break return history def evaluate_model(model: nn.Module, data_loader: DataLoader, device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu'), output_dir: Optional[str] = None) -> Dict[str, Union[float, np.ndarray]]: """ 评估模型 参数: model (nn.Module): 要评估的模型 data_loader (DataLoader): 数据加载器 device (torch.device): 计算设备 output_dir (str, 可选): 输出目录,如果提供,会保存混淆矩阵图 返回: Dict[str, Union[float, np.ndarray]]: 评估结果 """ logger.info("开始评估模型") model.eval() criterion = nn.CrossEntropyLoss() total_loss = 0.0 all_preds = [] all_targets = [] batch_count = 0 with torch.no_grad(): for features, targets in tqdm(data_loader, desc="评估"): batch_count += 1 # 将数据移动到设备 for phase in features: features[phase] = features[phase].to(device) targets = targets.to(device) # 前向传播 outputs = model(features) loss = criterion(outputs, targets) # 统计 total_loss += loss.item() _, predicted = outputs.max(1) # 收集预测和真实标签 all_preds.extend(predicted.cpu().numpy()) all_targets.extend(targets.cpu().numpy()) # 计算平均损失 - 使用批次计数 avg_loss = total_loss / batch_count if batch_count > 0 else 0.0 if batch_count == 0: logger.warning("评估时数据加载器为空或未产生批次。") # 计算准确率 accuracy = sum(np.array(all_preds) == np.array(all_targets)) / len(all_targets) if len(all_targets) > 0 else 0.0 # 计算F1分数 f1 = f1_score(all_targets, all_preds, average='macro', zero_division=0) # 计算混淆矩阵 cm = confusion_matrix(all_targets, all_preds) # 如果提供了输出目录,保存混淆矩阵图 if output_dir: if not os.path.exists(output_dir): os.makedirs(output_dir) plt.figure(figsize=(10, 8)) sns.heatmap(cm, annot=True, fmt='d', cmap='Blues') plt.xlabel('预测标签') plt.ylabel('真实标签') plt.title('混淆矩阵') plt.savefig(os.path.join(output_dir, 'confusion_matrix.png')) plt.close() # 保存分类报告 report_dict = classification_report(all_targets, all_preds, output_dict=True, zero_division=0) with open(os.path.join(output_dir, 'classification_report.json'), 'w') as f: json.dump(report_dict, f, indent=4) report_str = classification_report(all_targets, all_preds, zero_division=0) logger.info(f"分类报告:\n{report_str}") logger.info(f"评估结果 - 损失: {avg_loss:.4f}, 准确率: {accuracy:.4f}, F1分数: {f1:.4f}") return { 'loss': avg_loss, 'accuracy': accuracy, 'f1': f1, 'confusion_matrix': cm, 'predictions': np.array(all_preds), 'targets': np.array(all_targets) } def plot_training_history(history: Dict[str, List], save_path: Optional[str] = None): """ 绘制训练历史 参数: history (Dict[str, List]): 训练历史 save_path (str, 可选): 保存路径,不提供则显示图像 """ plt.figure(figsize=(15, 10)) # 绘制损失 plt.subplot(2, 2, 1) plt.plot(history['train_loss'], label='训练损失') plt.plot(history['val_loss'], label='验证损失') plt.title('损失') plt.xlabel('轮次') plt.ylabel('损失') plt.legend() plt.grid(True) # 绘制准确率 plt.subplot(2, 2, 2) plt.plot(history['train_acc'], label='训练准确率') plt.plot(history['val_acc'], label='验证准确率') plt.title('准确率') plt.xlabel('轮次') plt.ylabel('准确率') plt.legend() plt.grid(True) # 绘制F1分数 plt.subplot(2, 2, 3) plt.plot(history['train_f1'], label='训练F1') plt.plot(history['val_f1'], label='验证F1') plt.title('F1分数') plt.xlabel('轮次') plt.ylabel('F1') plt.legend() plt.grid(True) plt.tight_layout() if save_path: plt.savefig(save_path) plt.close() else: plt.show() # 模型预测函数 def predict_sample(model: nn.Module, a_phase: str, b_phase: str, c_phase: str, p_phase: str, target_length: int = 150, alignment_method: str = 'linear_interpolation', scalers: Dict[str, Dict[str, float]] = None, device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')) -> Dict[str, Union[int, float]]: """ 预测单个样本 参数: model (nn.Module): 模型 a_phase (str): A相曲线字符串 b_phase (str): B相曲线字符串 c_phase (str): C相曲线字符串 p_phase (str): P相曲线字符串 target_length (int): 目标序列长度 alignment_method (str): 对齐方法 scalers (Dict[str, Dict[str, float]]): 归一化参数 device (torch.device): 计算设备 返回: Dict[str, Union[int, float]]: 预测结果 """ model.eval() # 解析曲线 def parse_curve(curve_str): return [float(val) for val in curve_str.strip('|').split('|')] # 处理曲线 def align_curve(curve, target_length, method): if len(curve) == target_length: return np.array(curve) if method == 'linear_interpolation': original_indices = np.linspace(0, len(curve) - 1, len(curve)) new_indices = np.linspace(0, len(curve) - 1, target_length) return np.interp(new_indices, original_indices, curve) elif method == 'jump_point_alignment': return align_by_jump_point_local(curve, target_length) # 使用局部函数避免命名冲突 else: # 默认使用线性插值 original_indices = np.linspace(0, len(curve) - 1, len(curve)) new_indices = np.linspace(0, len(curve) - 1, target_length) return np.interp(new_indices, original_indices, curve) # 跳变点对齐 (局部版本) def align_by_jump_point_local(curve, target_length): curve_array = np.array(curve) jump_idx = find_jump_point_local(curve_array) # 使用局部函数 target_jump_idx = int(target_length * 0.3) left_pad = max(0, target_jump_idx - jump_idx) right_pad = max(0, (target_length - left_pad - len(curve))) if len(curve) + left_pad + right_pad > target_length: trim_left = max(0, jump_idx - target_jump_idx) trim_right = max(0, len(curve) - trim_left - (target_length - left_pad - right_pad)) trimmed_curve = curve_array[trim_left:len(curve)-trim_right] else: trimmed_curve = curve_array padded = np.zeros(target_length) padded[left_pad:left_pad+len(trimmed_curve)] = trimmed_curve return padded # 寻找跳变点 (局部版本) def find_jump_point_local(curve, threshold=0.5): if len(curve) <= 1: # 处理空或单点曲线 return len(curve) // 2 diffs = np.abs(np.diff(curve)) if len(diffs) == 0: # 如果差分后为空 (例如原始曲线只有1个点) return len(curve) // 2 mean_diff = np.mean(diffs) std_diff = np.std(diffs) jump_threshold = mean_diff + threshold * std_diff jump_indices = np.where(diffs > jump_threshold)[0] if len(jump_indices) > 0: return jump_indices[np.argmax(diffs[jump_indices])] + 1 else: return len(curve) // 2 # 归一化曲线 def normalize_curve(curve, phase, scalers): if scalers is None: return curve min_val = scalers[phase]['min'] max_val = scalers[phase]['max'] range_val = max_val - min_val if range_val == 0: return np.zeros_like(curve) return (curve - min_val) / range_val # 处理每个相的曲线 curves_data = { 'A相': parse_curve(a_phase), 'B相': parse_curve(b_phase), 'C相': parse_curve(c_phase), 'P相': parse_curve(p_phase) } # 对齐曲线 aligned_curves_data = { phase: align_curve(curves_data[phase], target_length, alignment_method) for phase in curves_data } # 归一化曲线 if scalers is not None: normalized_curves_data = { phase: normalize_curve(aligned_curves_data[phase], phase, scalers) for phase in aligned_curves_data } else: normalized_curves_data = aligned_curves_data # 转换为张量 features_tensor = { phase: torch.FloatTensor(normalized_curves_data[phase]).unsqueeze(0).to(device) for phase in normalized_curves_data } # 预测 with torch.no_grad(): outputs = model(features_tensor) probabilities = F.softmax(outputs, dim=1) predicted_class = outputs.argmax(dim=1).item() confidence_score = probabilities[0, predicted_class].item() return { 'prediction': predicted_class, 'confidence': confidence_score, 'probabilities': probabilities[0].cpu().numpy().tolist() } # 生成样本曲线的图像 def generate_curve_plot(a_phase: str, b_phase: str, c_phase: str, p_phase: str) -> str: """ 生成曲线图像并返回Base64编码的图像 参数: a_phase (str): A相曲线字符串 b_phase (str): B相曲线字符串 c_phase (str): C相曲线字符串 p_phase (str): P相曲线字符串 返回: str: Base64编码的图像 """ # 解析曲线 def parse_curve(curve_str): return [float(val) for val in curve_str.strip('|').split('|')] curves = { 'A相': parse_curve(a_phase), 'B相': parse_curve(b_phase), 'C相': parse_curve(c_phase), 'P相': parse_curve(p_phase) } plt.figure(figsize=(12, 8)) # 创建一个2x2的子图布局 phases_list = ['A相', 'B相', 'C相', 'P相'] for i, phase_name in enumerate(phases_list): plt.subplot(2, 2, i+1) plt.plot(curves[phase_name]) plt.title(f"{phase_name} 曲线") plt.xlabel("时间步") plt.ylabel("值") plt.grid(True) plt.tight_layout() # 将图像保存到内存中的BytesIO对象 buffer = io.BytesIO() plt.savefig(buffer, format='png') plt.close() # 获取图像的Base64编码 buffer.seek(0) img_str = base64.b64encode(buffer.read()).decode('utf-8') return img_str # FastAPI应用 app = FastAPI(title="故障曲线分类预测") templates = Jinja2Templates(directory="templates") class InputData(BaseModel): """API输入数据模型""" a_phase: str b_phase: str c_phase: str p_phase: str # 全局模型和归一化器 global_model = None global_scalers = None def load_model_for_api(model_path: str = 'best_model.pth', num_classes: int = 9): """ 加载已训练的模型 (API专用,避免与主训练流程中的model变量混淆) 参数: model_path (str): 模型文件路径 num_classes (int): 类别数 返回: nn.Module: 加载的模型 """ global global_model if global_model is not None: return global_model try: loaded_model = LSTMCNN(num_classes=num_classes) try: checkpoint = torch.load(model_path, map_location=device, weights_only=False) except FileNotFoundError: logger.error(f"模型文件未找到 (API): {model_path}") raise except Exception as e: logger.error(f"加载检查点时出错 (API): {str(e)}") raise # 加载模型权重 if 'model_state_dict' in checkpoint: loaded_model.load_state_dict(checkpoint['model_state_dict']) else: loaded_model.load_state_dict(checkpoint) logger.info(f"从 {model_path} 加载模型成功 (API), 类别数: {num_classes}") except Exception as e: logger.error(f"加载模型失败 (API): {str(e)}") raise loaded_model = loaded_model.to(device) loaded_model.eval() global_model = loaded_model # 赋值给全局变量 return global_model def load_scalers_for_api(scalers_path: str = 'scalers.pkl'): """ 加载归一化器 (API专用) 参数: scalers_path (str): 归一化器文件路径 返回: Dict[str, Dict[str, float]]: 加载的归一化参数 """ global global_scalers # 修改这里以匹配全局变量名 if global_scalers is not None: return global_scalers try: import pickle with open(scalers_path, 'rb') as f: loaded_scalers = pickle.load(f) # 使用不同的局部变量名 logger.info(f"从 {scalers_path} 加载归一化器成功 (API)") global_scalers = loaded_scalers # 赋值给全局变量 except FileNotFoundError: logger.error(f"归一化器文件未找到: {scalers_path} (API). 将不使用归一化器。") global_scalers = None # 明确设置为None except Exception as e: logger.error(f"加载归一化器失败 (API): {str(e)}") global_scalers = None # 明确设置为None return global_scalers @app.get("/", response_class=HTMLResponse) async def get_index(request: Request): """ 返回主页 """ return templates.TemplateResponse("index.html", {"request": request}) @app.get("/index2", response_class=HTMLResponse) async def get_index2(request: Request): """ 返回增强版主页 """ return templates.TemplateResponse("index2.html", {"request": request}) @app.post("/predict") async def predict_api_endpoint(data: InputData): # 重命名函数以避免与内部predict函数混淆 """ 预测API端点 参数: data (InputData): 输入数据 返回: Dict: 预测结果 """ try: # 加载模型和归一化器 # 假设模型和scalers在启动时或第一次调用时加载 # 如果尚未加载,则加载它们 # 注意:这里的num_classes需要与训练时一致,可以考虑从配置文件读取或从模型检查点获取 # 暂时硬编码为9,如果您的类别数不同,请修改 current_model = load_model_for_api(num_classes=9) # 确保这里的num_classes正确 current_scalers = load_scalers_for_api() if current_model is None: raise HTTPException(status_code=500, detail="模型未能加载") # 预测 result = predict_sample( model=current_model, a_phase=data.a_phase, b_phase=data.b_phase, c_phase=data.c_phase, p_phase=data.p_phase, scalers=current_scalers, # 传递加载的scalers device=device ) # 生成图像 plot_img = generate_curve_plot(data.a_phase, data.b_phase, data.c_phase, data.p_phase) return { "success": True, "prediction": result["prediction"], "confidence": result["confidence"], "probabilities": result["probabilities"], "plot_image": plot_img } except HTTPException as http_exc: # 重新抛出HTTPException raise http_exc except Exception as e: logger.exception("预测时出错 (API)") raise HTTPException(status_code=500, detail=f"内部服务器错误: {str(e)}") def get_data_splits_indices(file_path: str, train_ratio: float, val_ratio: float, random_state: int = 42) -> Tuple[List[int], List[int], List[int]]: """ 计算并返回用于训练、验证和测试集的打乱后的行索引。 索引是相对于数据文件的行号(0-based, 不包括表头)。 """ logger.info("正在为流式数据集确定行索引...") try: with open(file_path, 'r', encoding='utf-8') as f: # 计算总行数 (跳过表头) total_lines = sum(1 for _ in f) - 1 if total_lines <= 0: raise ValueError("数据文件为空或只有表头。") except FileNotFoundError: logger.error(f"数据文件未找到: {file_path}") raise except Exception as e: logger.error(f"读取文件行数时出错: {e}") raise all_indices = list(range(total_lines)) # 使用numpy进行可复现的打乱 rng = np.random.RandomState(random_state) rng.shuffle(all_indices) # 计算每个集合的样本数 train_count = int(total_lines * train_ratio) val_count = int(total_lines * val_ratio) train_indices = all_indices[:train_count] val_indices = all_indices[train_count : train_count + val_count] test_indices = all_indices[train_count + val_count:] logger.info(f"总数据行数: {total_lines}") logger.info(f"训练集索引数: {len(train_indices)}, 验证集索引数: {len(val_indices)}, 测试集索引数: {len(test_indices)}") # 确保没有重叠且覆盖所有索引 assert len(set(train_indices) | set(val_indices) | set(test_indices)) == total_lines assert len(set(train_indices) & set(val_indices)) == 0 assert len(set(train_indices) & set(test_indices)) == 0 assert len(set(val_indices) & set(test_indices)) == 0 return train_indices, val_indices, test_indices # 主函数 def main(): """主函数,用于训练和评估模型""" try: # 设置输出目录 output_dir = "model_output" os.makedirs(output_dir, exist_ok=True) # 定义数据集文件路径 data_path = 'curve_data_pivot_new.csv' if not os.path.exists(data_path): logger.error(f"数据文件 {data_path} 未找到。请确保文件路径正确。") return # 首先进行一次文件扫描,获取基本统计信息和类别数量 logger.info("扫描数据集获取基本信息...") # 使用纯CSV读取而不是pandas,更节省内存 class_counts = {} curve_lengths = {'A相': [], 'B相': [], 'C相': [], 'P相': []} curve_mins = {'A相': float('inf'), 'B相': float('inf'), 'C相': float('inf'), 'P相': float('inf')} curve_maxs = {'A相': float('-inf'), 'B相': float('-inf'), 'C相': float('-inf'), 'P相': float('-inf')} # 仅采样部分数据计算统计值 total_rows = 0 sample_interval = 100 # 每100行采样一次 with open(data_path, 'r', encoding='utf-8') as f: reader = csv.reader(f) header = next(reader) # 跳过表头 for i, row in enumerate(tqdm(reader, desc="扫描数据集")): total_rows += 1 # 统计类别分布 try: label = int(row[3]) if label in class_counts: class_counts[label] += 1 else: class_counts[label] = 1 except (IndexError, ValueError) as e: logger.warning(f"解析标签时跳过行 {i+1}: {e} - Row: {row}") continue # 采样部分数据计算统计值 if i % sample_interval == 0: for j, phase in enumerate(['A相', 'B相', 'C相', 'P相']): try: curve_str = row[j+4] values = [float(val) for val in curve_str.strip('|').split('|') if val] # 确保val不为空 if not values: continue # 如果解析后为空列表,跳过 curve_lengths[phase].append(len(values)) # 更新最小值和最大值 min_val = min(values) max_val = max(values) curve_mins[phase] = min(curve_mins[phase], min_val) curve_maxs[phase] = max(curve_maxs[phase], max_val) except (IndexError, ValueError) as e: logger.warning(f"解析曲线数据时跳过行 {i+1}, 相 {phase}: {e} - Row: {row}") continue if total_rows == 0: logger.error("数据文件为空或未能读取任何行。") return if not class_counts: logger.error("未能从数据文件中解析任何类别标签。请检查数据格式。") return # 计算每个相的长度统计和归一化参数 scalers = {} for phase in ['A相', 'B相', 'C相', 'P相']: # 确保在max/min仍为inf/-inf时(即没有有效数据点)提供默认值 scalers[phase] = { 'min': curve_mins[phase] if curve_mins[phase] != float('inf') else 0.0, 'max': curve_maxs[phase] if curve_maxs[phase] != float('-inf') else 1.0 } # 如果min和max相同,调整max以避免除零 if scalers[phase]['min'] == scalers[phase]['max']: scalers[phase]['max'] += 1e-6 # 添加一个小值 # 使用计算得到的统计信息 logger.info(f"数据集总行数: {total_rows}") logger.info(f"类别分布: {class_counts}") # 确定 num_classes 基于实际观察到的最大类别索引 if not class_counts: logger.error("无法确定类别数量,因为类别分布为空。") return # 或者设置一个默认值,但这可能不正确 num_classes = max(class_counts.keys()) + 1 logger.info(f"根据数据推断的类别数量 (max_label + 1): {num_classes}") # 保存归一化器 import pickle scalers_path = os.path.join(output_dir, 'scalers.pkl') with open(scalers_path, 'wb') as f: pickle.dump(scalers, f) logger.info(f"归一化器已保存到 {scalers_path}") # 计算类别权重 class_weights_tensor = torch.zeros(num_classes) # 使用推断的num_classes # 确保所有类别(从0到num_classes-1)都有一个权重,即使它们在数据中不存在 for class_idx in range(num_classes): class_count = class_counts.get(class_idx, 0) # 获取计数,如果不存在则为0 if class_count > 0: # 使用实际观察到的类别数量 len(class_counts) 进行加权 class_weights_tensor[class_idx] = total_rows / (len(class_counts) * class_count) else: # 对于数据中未出现的类别,可以设置权重为0或1,或基于其他策略 # 设置为0意味着损失函数中不会考虑这些类别(如果它们意外出现) # 设置为1意味着它们具有“平均”权重(如果它们是可能的但稀有的类别) class_weights_tensor[class_idx] = 1.0 # 或0.0,取决于策略 logger.warning(f"类别 {class_idx} 在数据中未找到,权重设置为1.0。") logger.info(f"类别权重: {class_weights_tensor}") # 获取打乱后的数据划分索引 train_indices, val_indices, test_indices = get_data_splits_indices( file_path=data_path, train_ratio=0.7, val_ratio=0.15, random_state=42 ) # 创建流式训练集 train_dataset = StreamingFaultDataset( file_path=data_path, assigned_indices_list=train_indices, batch_size=32, target_length=150, alignment_method='jump_point_alignment', is_train=True, scalers=scalers, buffer_size=1000 ) # 创建流式验证集 val_dataset = StreamingFaultDataset( file_path=data_path, assigned_indices_list=val_indices, batch_size=32, target_length=150, alignment_method='jump_point_alignment', is_train=False, scalers=scalers, buffer_size=1000 ) # 创建流式测试集 test_dataset = StreamingFaultDataset( file_path=data_path, assigned_indices_list=test_indices, batch_size=32, target_length=150, alignment_method='jump_point_alignment', is_train=False, scalers=scalers, buffer_size=1000 ) # 创建数据加载器 # DataLoader的batch_size=None表示数据集本身产生批次 train_loader = DataLoader( train_dataset, batch_size=None, num_workers=0 # 多进程在Windows上可能需要特殊处理,暂时设为0 ) val_loader = DataLoader( val_dataset, batch_size=None, num_workers=0 ) test_loader = DataLoader( test_dataset, batch_size=None, num_workers=0 ) # 创建模型 model_instance = LSTMCNN( # 使用不同的变量名 input_channels=4, seq_length=150, num_classes=num_classes, cnn_out_channels=128, lstm_hidden_size=set_lstm_hidden_size, lstm_layers=set_lstm_layers, dropout=set_dropout, bidirectional=True ) # 打印模型结构 logger.info(f"模型结构: {model_instance}") # 训练模型 logger.info("开始训练模型...") checkpoint_path = os.path.join(output_dir, 'best_model.pth') history = train_model( model=model_instance, train_loader=train_loader, val_loader=val_loader, num_epochs=100, learning_rate=0.001, weight_decay=0.01, patience=20, checkpoint_path=checkpoint_path, class_weights=class_weights_tensor, device=device ) # 绘制训练历史 history_path = os.path.join(output_dir, 'training_history.png') plot_training_history(history, save_path=history_path) logger.info(f"训练历史已保存到 {history_path}") # 加载最佳模型进行评估 best_model_for_eval = LSTMCNN( # 使用不同的变量名 input_channels=4, seq_length=150, num_classes=num_classes, cnn_out_channels=128, lstm_hidden_size=set_lstm_hidden_size, lstm_layers=set_lstm_layers, dropout=set_dropout, bidirectional=True ) if not os.path.exists(checkpoint_path): logger.error(f"最佳模型检查点 {checkpoint_path} 未找到。跳过评估。") else: checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) # 兼容旧的只保存模型权重的格式和新的包含 'model_state_dict' 的格式 if 'model_state_dict' in checkpoint: best_model_for_eval.load_state_dict(checkpoint['model_state_dict']) else: best_model_for_eval.load_state_dict(checkpoint) logger.info(f"已从 {checkpoint_path} 加载最佳模型权重进行评估。") best_model_for_eval = best_model_for_eval.to(device) # 评估模型 logger.info("评估模型在测试集上的表现...") evaluation_dir = os.path.join(output_dir, 'evaluation') os.makedirs(evaluation_dir, exist_ok=True) test_results = evaluate_model( model=best_model_for_eval, data_loader=test_loader, device=device, output_dir=evaluation_dir ) logger.info(f"测试结果 - 损失: {test_results['loss']:.4f}, 准确率: {test_results['accuracy']:.4f}, F1分数: {test_results['f1']:.4f}") # 保存评估结果 evaluation_results_data = { # 使用不同的变量名 'loss': float(test_results['loss']), 'accuracy': float(test_results['accuracy']), 'f1': float(test_results['f1']), 'confusion_matrix': test_results['confusion_matrix'].tolist() } with open(os.path.join(evaluation_dir, 'evaluation_results.json'), 'w') as f: json.dump(evaluation_results_data, f, indent=4) logger.info("模型训练和评估完成!") except Exception as e: logger.exception("训练过程中出错") # raise # 在生产环境中可能不想重新抛出,而是记录并优雅退出 # API服务器启动函数 def run_api_server(host="0.0.0.0", port=8000, num_classes_api=9): # 添加num_classes参数 """ 启动API服务器 参数: host (str): 主机 port (int): 端口 num_classes_api (int): API模型加载时使用的类别数 """ # 在服务器启动时预加载模型和scalers logger.info("正在为API预加载模型和归一化器...") load_model_for_api(model_path=os.path.join("model_output", 'best_model.pth')) load_scalers_for_api(scalers_path=os.path.join("model_output", 'scalers.pkl')) logger.info("模型和归一化器预加载完成 (API)。") import uvicorn uvicorn.run(app, host=host, port=port) def configure_chinese_font(): """配置中文字体,根据系统可用字体选择合适的中文字体""" import matplotlib.font_manager as fm import platform # 获取所有可用的字体 font_names = [f.name for f in fm.fontManager.ttflist] # 按优先级排序的中文字体列表 chinese_fonts = [ 'Noto Sans CJK SC', # 思源黑体简体中文 'Noto Sans CJK JP', # 思源黑体日文(包含中文) 'Noto Sans CJK TC', # 思源黑体繁体中文 'WenQuanYi Micro Hei', # 文泉驿微米黑 'Microsoft YaHei', # 微软雅黑 'SimHei', # 黑体 'SimSun', # 宋体 'Arial Unicode MS', # Arial Unicode ] # 检查系统中可用的中文字体 available_chinese_fonts = [font for font in chinese_fonts if font in font_names] if available_chinese_fonts: # 使用找到的第一个可用中文字体 plt.rcParams['font.sans-serif'] = available_chinese_fonts print(f"使用中文字体: {available_chinese_fonts[0]}") else: # 如果没有找到任何预定义的中文字体,使用系统默认字体 print("警告: 未找到预设的中文字体,可能无法正确显示中文") plt.rcParams['axes.unicode_minus'] = False plt.rcParams['font.family'] = 'sans-serif' # 返回使用的字体,便于调试 return available_chinese_fonts[0] if available_chinese_fonts else None # 如果直接运行此脚本 if __name__ == "__main__": configure_chinese_font() # 在Windows上确保多进程支持 if os.name == 'nt': # 更明确地检查操作系统 multiprocessing.freeze_support() import argparse parser = argparse.ArgumentParser(description="故障曲线分类预测系统") parser.add_argument("--mode", type=str, choices=["train", "serve"], default="train", help="运行模式: train (训练模型) 或 serve (启动API服务)") parser.add_argument("--port", type=int, default=8000, help="API服务端口") parser.add_argument("--host", type=str, default="0.0.0.0", help="API服务主机") # 添加一个参数来传递类别数给API服务,因为这可能与训练时动态确定的不同 # 或者,更好的做法是从保存的检查点或配置文件中读取 args = parser.parse_args() if args.mode == "train": main() elif args.mode == "serve": # 确保输出目录存在,以便API可以加载模型/scalers output_dir = "model_output" if not os.path.exists(os.path.join(output_dir, 'best_model.pth')): logger.warning(f"API模式:模型文件 {os.path.join(output_dir, 'best_model.pth')} 未找到。请先训练模型。") # 可以选择退出或尝试继续(如果API逻辑能处理模型不存在的情况) if not os.path.exists(os.path.join(output_dir, 'scalers.pkl')): logger.warning(f"API模式:归一化文件 {os.path.join(output_dir, 'scalers.pkl')} 未找到。预测可能不准确。") run_api_server(host=args.host, port=args.port)

导入所需的库

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, IterableDataset
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
import os
import json
import time
import csv
from tqdm import tqdm
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
import uvicorn
import logging
import io
import base64
from typing import List, Dict, Union, Tuple, Optional
import gc
import multiprocessing
from functools import partial
from concurrent.futures import ProcessPoolExecutor

set_lstm_hidden_size=128
set_lstm_layers = 6
set_dropout=0.5

配置日志

logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("fault_prediction.log"),
logging.StreamHandler()
]
)
logger = logging.getLogger(name)

设置随机种子以确保可重现性

def set_seed(seed=42):
"""
设置所有随机种子以确保结果可重现

text
参数: seed (int): 随机种子值 """ np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) # 如果使用多GPU torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False

set_seed()

设备配置

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"使用设备: {device}")

检查是否有多个GPU可用

if torch.cuda.device_count() > 1:
logger.info(f"使用 {torch.cuda.device_count()} 个GPU进行训练")
multi_gpu = True
else:
multi_gpu = False

数据集定义

class FaultDataset(Dataset):
"""
用于处理故障数据的自定义数据集类
"""
def init(self,
file_path: str = None,
data: pd.DataFrame = None,
normalize: bool = True,
target_length: int = 150,
alignment_method: str = 'linear_interpolation',
augment: bool = False,
scalers: dict = None):
"""
初始化数据集

text
参数: file_path (str): CSV文件路径,可以为None如果直接提供dataframe data (pd.DataFrame): 直接提供的数据帧,可以为None如果提供file_path normalize (bool): 是否对数据进行归一化处理 target_length (int): 目标序列长度,用于对齐 alignment_method (str): 对齐方法,可选值为: 'linear_interpolation': 线性插值 'pad_zeros': 零填充 'dynamic_time_warping': 动态时间规整 'fourier_transform': 傅里叶变换重采样 augment (bool): 是否进行数据增强 scalers (dict): 已经拟合好的归一化器,用于测试或验证集 """ self.target_length = target_length self.normalize = normalize self.alignment_method = alignment_method self.augment = augment # 从文件加载数据或使用提供的数据帧 if file_path is not None: self.data = self.load_data(file_path) elif data is not None: self.data = data else: raise ValueError("必须提供file_path或data参数之一") # 分离特征和标签 self.features = self.data.drop(['data_id', 'measure_time', 'turnout_num', 'label'], axis=1) self.labels = self.data['label'].values # 存储原始曲线长度 self.original_lengths = { 'A相': [], 'B相': [], 'C相': [], 'P相': [] } # 处理和对齐曲线数据 self.processed_features = self.process_curves() # 数据归一化 if self.normalize: if scalers is None: self.scalers = self.fit_scalers() else: self.scalers = scalers # 应用归一化 self.normalized_features = self.apply_normalization() else: self.normalized_features = self.processed_features def load_data(self, file_path: str) -> pd.DataFrame: """ 从CSV文件加载数据,支持大文件分块处理以避免内存问题 参数: file_path (str): CSV文件路径 返回: pd.DataFrame: 加载的数据帧 """ # 检查文件是否存在 if not os.path.exists(file_path): raise FileNotFoundError(f"文件不存在: {file_path}") logger.info(f"从 {file_path} 加载数据...") try: # 对于大文件,使用分块读取以避免内存问题 chunk_size = 10000 # 每次读取的行数 chunks = [] for chunk in tqdm(pd.read_csv(file_path, chunksize=chunk_size), desc="加载数据块"): chunks.append(chunk) data = pd.concat(chunks, ignore_index=True) logger.info(f"数据加载完成。形状: {data.shape}") return data except Exception as e: logger.error(f"加载数据时出错: {str(e)}") raise def parse_curve(self, curve_str: str) -> List[float]: """ 解析曲线字符串为数值列表 参数: curve_str (str): 曲线字符串,格式如 "|0.0|0.15|0.6|..." 返回: List[float]: 解析后的数值列表 """ # 去除首尾的'|'并按'|'分割 values = curve_str.strip('|').split('|') # 转换为浮点数 return [float(val) for val in values] def process_curves(self) -> Dict[str, np.ndarray]: """ 处理所有曲线,解析并对齐到目标长度 返回: Dict[str, np.ndarray]: 处理后的特征字典,键为相名,值为对齐后的数组 """ processed = { 'A相': [], 'B相': [], 'C相': [], 'P相': [] } # 向量化处理曲线数据 for phase in processed.keys(): # 预处理字符串 curves = self.features[phase].str.strip('|').str.split('|') # 转换为数值型列表 curves = curves.apply(lambda x: [float(val) for val in x]) # 记录原始长度 self.original_lengths[phase] = curves.apply(len).values # 批量对齐 aligned_curves = [] for curve in tqdm(curves, desc=f"对齐{phase}曲线"): aligned_curves.append(self.align_curve(curve, self.target_length)) processed[phase] = np.array(aligned_curves) logger.info(f"曲线处理完成。每相形状: {processed['A相'].shape}") return processed def align_curve(self, curve: List[float], target_length: int) -> np.ndarray: """ 将曲线对齐到目标长度 参数: curve (List[float]): 原始曲线数据 target_length (int): 目标长度 返回: np.ndarray: 对齐后的曲线 """ if len(curve) == target_length: return np.array(curve) # 根据选择的方法进行对齐 if self.alignment_method == 'linear_interpolation': # 线性插值对齐 original_indices = np.linspace(0, len(curve) - 1, len(curve)) new_indices = np.linspace(0, len(curve) - 1, target_length) return np.interp(new_indices, original_indices, curve) elif self.alignment_method == 'pad_zeros': # 零填充对齐 if len(curve) > target_length: # 如果曲线长度大于目标长度,则截断 return np.array(curve[:target_length]) else: # 如果曲线长度小于目标长度,则填充零 padded = np.zeros(target_length) padded[:len(curve)] = curve return padded elif self.alignment_method == 'dynamic_time_warping': # 动态时间规整需要第三方库,这里只实现简化版本 scale_factor = target_length / len(curve) indices = np.round(np.arange(0, len(curve), 1/scale_factor)).astype(int) indices = indices[:target_length] return np.array([curve[min(i, len(curve)-1)] for i in indices]) elif self.alignment_method == 'fourier_transform': # 使用傅里叶变换进行重采样 try: from scipy import fftpack # 获取频域表示 curve_fft = fftpack.fft(curve) # 如果目标长度大于原始长度,补零 if target_length > len(curve): pad_size = target_length - len(curve) curve_fft_padded = np.pad(curve_fft, (0, pad_size), 'constant') # 如果目标长度小于原始长度,截断高频部分 else: curve_fft_padded = curve_fft[:target_length] # 逆变换回时域 return fftpack.ifft(curve_fft_padded).real except ImportError: logger.warning("scipy未安装,回退到线性插值") # 回退到线性插值 original_indices = np.linspace(0, len(curve) - 1, len(curve)) new_indices = np.linspace(0, len(curve) - 1, target_length) return np.interp(new_indices, original_indices, curve) elif self.alignment_method == 'jump_point_alignment': return self.align_by_jump_point(curve, target_length) else: # 默认使用线性插值 logger.warning(f"未知的对齐方法: {self.alignment_method},使用线性插值") original_indices = np.linspace(0, len(curve) - 1, len(curve)) new_indices = np.linspace(0, len(curve) - 1, target_length) return np.interp(new_indices, original_indices, curve) def fit_scalers(self) -> Dict[str, Dict[str, float]]: """ 为每个相拟合归一化器,优化内存使用 返回: Dict[str, Dict[str, float]]: 每个相的归一化参数 """ scalers = {} for phase in self.processed_features.keys(): data = self.processed_features[phase] min_val = np.min(data) max_val = np.max(data) scalers[phase] = {'min': min_val, 'max': max_val} logger.info("归一化器拟合完成") return scalers def apply_normalization(self) -> Dict[str, np.ndarray]: """ 应用归一化到处理后的特征,优化内存使用 返回: Dict[str, np.ndarray]: 归一化后的特征 """ normalized = {} for phase in self.processed_features.keys(): data = self.processed_features[phase] min_val = self.scalers[phase]['min'] max_val = self.scalers[phase]['max'] range_val = max_val - min_val if range_val == 0: normalized[phase] = np.zeros_like(data) else: normalized[phase] = (data - min_val) / range_val logger.info("特征归一化完成") return normalized def augment_data(self, data: Dict[str, np.ndarray], label: int) -> Tuple[Dict[str, np.ndarray], int]: """ 对数据进行增强 参数: data (Dict[str, np.ndarray]): 要增强的数据 label (int): 对应的标签 返回: Tuple[Dict[str, np.ndarray], int]: 增强后的数据和标签 """ # 只有在训练阶段才进行数据增强 if not self.augment: return data, label # 随机选择增强方法 augmentation_choice = np.random.rand() # 对数据的一个副本进行操作 augmented_data = {} for phase in data.keys(): augmented_data[phase] = data[phase].copy() if augmentation_choice < 0.25: # 方法1: 添加随机噪声 noise_level = np.random.uniform(0.01, 0.03) for phase in augmented_data.keys(): noise = np.random.normal(0, noise_level, size=augmented_data[phase].shape) augmented_data[phase] = augmented_data[phase] + noise elif augmentation_choice < 0.5: # 方法2: 时间滑动(左右移动) shift_amount = int(np.random.uniform(-10, 10)) for phase in augmented_data.keys(): if shift_amount > 0: augmented_data[phase] = np.pad(augmented_data[phase], (0, shift_amount), 'constant')[:self.target_length] else: augmented_data[phase] = np.pad(augmented_data[phase], (-shift_amount, 0), 'constant')[:-shift_amount if shift_amount != 0 else None] elif augmentation_choice < 0.75: # 方法3: 幅度缩放 scale_factor = np.random.uniform(0.9, 1.1) for phase in augmented_data.keys(): augmented_data[phase] = augmented_data[phase] * scale_factor else: # 方法4: 反转部分片段 if self.target_length > 10: # 随机选择一个片段进行反转 start_idx = np.random.randint(0, self.target_length - 10) end_idx = np.random.randint(start_idx + 5, min(start_idx + 10, self.target_length)) for phase in augmented_data.keys(): segment = augmented_data[phase][start_idx:end_idx].copy() augmented_data[phase][start_idx:end_idx] = segment[::-1] # 反转片段 return augmented_data, label def __len__(self): """返回数据集的大小""" return len(self.labels) def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], int]: """ 获取指定索引的样本 参数: idx (int): 样本索引 返回: Tuple[Dict[str, torch.Tensor], int]: 特征字典和标签 """ # 获取所有相的数据 sample = { 'A相': self.normalized_features['A相'][idx], 'B相': self.normalized_features['B相'][idx], 'C相': self.normalized_features['C相'][idx], 'P相': self.normalized_features['P相'][idx], } label = self.labels[idx] # 数据增强 if self.augment and np.random.rand() < 0.5: # 50%的概率进行增强 sample, label = self.augment_data(sample, label) # 转换为张量 tensor_sample = {phase: torch.FloatTensor(sample[phase]) for phase in sample.keys()} return tensor_sample, label def get_curve_length_stats(self) -> Dict[str, Dict[str, Union[float, int]]]: """ 获取曲线长度的统计信息 返回: Dict[str, Dict[str, Union[float, int]]]: 每个相的长度统计信息 """ stats = {} for phase in self.original_lengths.keys(): lengths = self.original_lengths[phase] stats[phase] = { 'min': np.min(lengths), 'max': np.max(lengths), 'mean': np.mean(lengths), 'median': np.median(lengths), 'std': np.std(lengths) } return stats def get_class_distribution(self) -> Dict[int, int]: """ 获取类别分布 返回: Dict[int, int]: 类别分布字典,键为类别,值为样本数 """ unique, counts = np.unique(self.labels, return_counts=True) return dict(zip(unique, counts)) @classmethod def create_train_val_test_sets(cls, file_path: str, train_ratio: float = 0.7, val_ratio: float = 0.15, test_ratio: float = 0.15, target_length: int = 150, alignment_method: str = 'linear_interpolation', stratify: bool = True, balance_method: Optional[str] = None, compute_class_weights: bool = False, random_state: int = 42) -> Tuple[Dataset, Dataset, Dataset, Optional[torch.Tensor]]: """ 创建训练、验证和测试集 参数: file_path (str): CSV文件路径 train_ratio (float): 训练集比例 val_ratio (float): 验证集比例 test_ratio (float): 测试集比例 target_length (int): 目标序列长度 alignment_method (str): 对齐方法 stratify (bool): 是否进行分层抽样 balance_method (str, 可选): 处理不平衡的方法,可选值: 'undersample': 欠采样 'oversample': 过采样 None: 不进行平衡处理 compute_class_weights (bool): 是否计算类别权重 random_state (int): 随机种子 返回: Tuple[Dataset, Dataset, Dataset, Optional[torch.Tensor]]: 训练集、验证集、测试集和类别权重 """ # 加载数据 temp_dataset = cls(file_path=file_path, normalize=False) # 获取标签分布 class_distribution = temp_dataset.get_class_distribution() logger.info(f"类别分布: {class_distribution}") # 创建数据帧,包含所有数据 df = temp_dataset.data # 分层划分数据集 if stratify: stratify_labels = df['label'] else: stratify_labels = None # 首先划分出测试集 train_val_df, test_df = train_test_split( df, test_size=test_ratio, random_state=random_state, stratify=stratify_labels ) # 然后从剩余数据中划分出验证集 if stratify: stratify_labels = train_val_df['label'] else: stratify_labels = None train_df, val_df = train_test_split( train_val_df, test_size=val_ratio/(train_ratio+val_ratio), random_state=random_state, stratify=stratify_labels ) logger.info(f"划分完成。训练集: {len(train_df)},验证集: {len(val_df)},测试集: {len(test_df)}") # 处理不平衡数据 if balance_method == 'undersample': # 欠采样:减少多数类的样本 train_df = cls._undersample(train_df) logger.info(f"欠采样后的训练集大小: {len(train_df)}") elif balance_method == 'oversample': # 过采样:增加少数类的样本 train_df = cls._oversample(train_df) logger.info(f"过采样后的训练集大小: {len(train_df)}") # 创建完整的数据集 # 首先创建训练集并拟合归一化器 train_dataset = cls( data=train_df, normalize=True, target_length=target_length, alignment_method=alignment_method, augment=True ) # 然后使用训练集的归一化器创建验证集和测试集 val_dataset = cls( data=val_df, normalize=True, target_length=target_length, alignment_method=alignment_method, augment=False, scalers=train_dataset.scalers ) test_dataset = cls( data=test_df, normalize=True, target_length=target_length, alignment_method=alignment_method, augment=False, scalers=train_dataset.scalers ) # 计算类别权重 class_weights = None if compute_class_weights: class_distribution = train_dataset.get_class_distribution() num_samples = len(train_dataset) num_classes = len(class_distribution) class_weights = torch.zeros(num_classes) for class_idx, class_count in class_distribution.items(): class_weights[class_idx] = num_samples / (num_classes * class_count) logger.info(f"类别权重: {class_weights}") return train_dataset, val_dataset, test_dataset, class_weights @staticmethod def _undersample(df: pd.DataFrame) -> pd.DataFrame: """ 对数据帧进行欠采样以平衡类别 参数: df (pd.DataFrame): 要处理的数据帧 返回: pd.DataFrame: 欠采样后的数据帧 """ class_counts = df['label'].value_counts() min_count = class_counts.min() # 对每个类别进行欠采样 balanced_dfs = [] for class_label in class_counts.index: class_df = df[df['label'] == class_label] # 如果类别样本数大于最小数量,随机采样 if len(class_df) > min_count: sampled_df = class_df.sample(min_count, random_state=42) balanced_dfs.append(sampled_df) else: balanced_dfs.append(class_df) # 合并所有采样后的数据 return pd.concat(balanced_dfs, ignore_index=True) @staticmethod def _oversample(df: pd.DataFrame) -> pd.DataFrame: """ 对数据帧进行过采样以平衡类别 参数: df (pd.DataFrame): 要处理的数据帧 返回: pd.DataFrame: 过采样后的数据帧 """ class_counts = df['label'].value_counts() max_count = class_counts.max() # 对每个类别进行过采样 balanced_dfs = [] for class_label in class_counts.index: class_df = df[df['label'] == class_label] # 如果类别样本数小于最大数量,随机重复采样 if len(class_df) < max_count: # 计算需要重复的次数 n_samples = max_count - len(class_df) # 使用替换采样 extra_samples = class_df.sample(n_samples, replace=True, random_state=42) balanced_dfs.append(pd.concat([class_df, extra_samples], ignore_index=True)) else: balanced_dfs.append(class_df) # 合并所有采样后的数据 return pd.concat(balanced_dfs, ignore_index=True) def visualization_sample(self, idx: int, save_path: Optional[str] = None) -> None: """ 可视化指定索引的样本 参数: idx (int): 样本索引 save_path (str, 可选): 保存路径,不提供则显示图像 """ sample, label = self[idx] plt.figure(figsize=(12, 8)) # 创建一个2x2的子图布局 phases = ['A相', 'B相', 'C相', 'P相'] for i, phase in enumerate(phases): plt.subplot(2, 2, i+1) plt.plot(sample[phase].numpy()) plt.title(f"{phase} 曲线") plt.xlabel("时间步") plt.ylabel("值") plt.grid(True) plt.suptitle(f"样本 {idx}, 标签: {label}") plt.tight_layout() if save_path: plt.savefig(save_path) plt.close() else: plt.show() def find_jump_point(self, curve: np.ndarray, threshold: float = 0.5) -> int: """ 寻找曲线中的跳变点 参数: curve: 曲线数据 threshold: 跳变阈值(百分比变化) 返回: int: 跳变点索引 """ # 计算差分 diffs = np.abs(np.diff(curve)) # 计算均值和标准差 mean_diff = np.mean(diffs) std_diff = np.std(diffs) # 设定跳变阈值(可调整) jump_threshold = mean_diff + threshold * std_diff # 找出超过阈值的点 jump_indices = np.where(diffs > jump_threshold)[0] if len(jump_indices) > 0: # 返回最显著的跳变点 return jump_indices[np.argmax(diffs[jump_indices])] + 1 else: # 没有显著跳变点,返回中点 return len(curve) // 2 def align_by_jump_point(self, curve: List[float], target_length: int) -> np.ndarray: """ 以跳变点为中心对齐曲线 参数: curve: 原始曲线 target_length: 目标长度 返回: 对齐后的曲线 """ curve_array = np.array(curve) # 找到跳变点 jump_idx = self.find_jump_point(curve_array) # 计算对齐后跳变点位置(例如,放在30%处) target_jump_idx = int(target_length * 0.3) # 计算需要的填充 left_pad = max(0, target_jump_idx - jump_idx) right_pad = max(0, (target_length - left_pad - len(curve))) # 如果原始曲线太长,需要截断 if len(curve) + left_pad + right_pad > target_length: # 计算截断量 trim_left = max(0, jump_idx - target_jump_idx) trim_right = max(0, len(curve) - trim_left - (target_length - left_pad - right_pad)) trimmed_curve = curve_array[trim_left:len(curve)-trim_right] else: trimmed_curve = curve_array # 填充 padded = np.zeros(target_length) padded[left_pad:left_pad+len(trimmed_curve)] = trimmed_curve return padded

class StreamingFaultDataset(IterableDataset):
"""
用于流式处理大数据集的自定义数据集类,避免一次性加载所有数据到内存
"""
def init(self,
file_path: str,
assigned_indices_list: List[int],
batch_size: int = 32,
target_length: int = 150,
alignment_method: str = 'jump_point_alignment',
is_train: bool = True,
scalers: dict = None,
buffer_size: int = 1000):
"""
初始化流式数据集

text
参数: file_path (str): CSV文件路径 batch_size (int): 批次大小 target_length (int): 目标序列长度 alignment_method (str): 对齐方法 is_train (bool): 是否为训练集 scalers (dict): 归一化参数 buffer_size (int): 缓冲区大小,用于加载和处理数据 train_ratio (float): 训练集比例 val_ratio (float): 验证集比例 test_ratio (float): 测试集比例 data_split (str): 数据集划分,'train', 'val', 或 'test' """ self.file_path = file_path self.assigned_indices = set(assigned_indices_list) self.num_samples_for_this_split = len(self.assigned_indices) self.batch_size = batch_size self.target_length = target_length self.alignment_method = alignment_method self.is_train = is_train self.scalers = scalers self.buffer_size = buffer_size logger.info(f"StreamingFaultDataset 初始化完成。将处理 {self.num_samples_for_this_split} 个样本。") def _find_jump_point(self, curve: np.ndarray, threshold: float = 0.5) -> int: """寻找曲线中的跳变点""" diffs = np.abs(np.diff(curve)) if len(diffs) == 0: return 0 mean_diff = np.mean(diffs) std_diff = np.std(diffs) jump_threshold = mean_diff + threshold * std_diff jump_indices = np.where(diffs > jump_threshold)[0] if len(jump_indices) > 0: return jump_indices[np.argmax(diffs[jump_indices])] + 1 else: return len(curve) // 2 def _align_by_jump_point(self, curve: List[float], target_length: int) -> np.ndarray: """以跳变点为中心对齐曲线""" curve_array = np.array(curve) jump_idx = self._find_jump_point(curve_array) target_jump_idx = int(target_length * 0.3) left_pad = max(0, target_jump_idx - jump_idx) right_pad = max(0, (target_length - left_pad - len(curve))) if len(curve) + left_pad + right_pad > target_length: trim_left = max(0, jump_idx - target_jump_idx) trim_right = max(0, len(curve) - trim_left - (target_length - left_pad - right_pad)) trimmed_curve = curve_array[trim_left:len(curve)-trim_right] else: trimmed_curve = curve_array padded = np.zeros(target_length) padded[left_pad:left_pad+len(trimmed_curve)] = trimmed_curve return padded def __iter__(self): """迭代器方法,按照数据集划分逐行读取文件并生成批次""" # 用于记录当前处理的行号 current_line = 0 # 缓冲区 buffer_features = { 'A相': [], 'B相': [], 'C相': [], 'P相': [] } buffer_labels = [] # 打开文件 with open(self.file_path, 'r', encoding='utf-8') as f: # 跳过标题行 next(f) # 逐行读取 for line_idx, line in enumerate(f): if line_idx not in self.assigned_indices: continue # 如果当前行不属于这个数据集划分,则跳过 try: # 解析行 row = line.strip().split(',') label = int(row[3]) # 解析和对齐四个相的曲线 a_phase = self._parse_and_align(row[4], 'A相') b_phase = self._parse_and_align(row[5], 'B相') c_phase = self._parse_and_align(row[6], 'C相') p_phase = self._parse_and_align(row[7], 'P相') # 添加到缓冲区 buffer_features['A相'].append(a_phase) buffer_features['B相'].append(b_phase) buffer_features['C相'].append(c_phase) buffer_features['P相'].append(p_phase) buffer_labels.append(label) # 当缓冲区满时,生成批次 if len(buffer_labels) >= self.batch_size: # 转换为张量 tensor_features = { phase: torch.FloatTensor(buffer_features[phase][:self.batch_size]) for phase in buffer_features } tensor_labels = torch.LongTensor(buffer_labels[:self.batch_size]) # 清空缓冲区,保留剩余样本 for phase in buffer_features: buffer_features[phase] = buffer_features[phase][self.batch_size:] buffer_labels = buffer_labels[self.batch_size:] # 如果是训练集且需要数据增强 if self.is_train and np.random.rand() < 0.3: # 30%概率执行增强 tensor_features = self._augment_batch(tensor_features) yield tensor_features, tensor_labels except Exception as e: logger.warning(f"跳过行 {line_idx}: {e}") # 处理剩余数据 if len(buffer_labels) > 0: # 修改:确保只要缓冲区有数据就处理 tensor_features = { phase: torch.FloatTensor(buffer_features[phase]) # 使用所有剩余数据 for phase in buffer_features } tensor_labels = torch.LongTensor(buffer_labels) # 使用所有剩余数据 if self.is_train and np.random.rand() < 0.3: tensor_features = self._augment_batch(tensor_features) yield tensor_features, tensor_labels def _parse_and_align(self, curve_str: str, phase: str) -> np.ndarray: """解析曲线字符串并对齐到目标长度""" # 解析曲线 values = curve_str.strip('|').split('|') curve = [float(val) for val in values] # 对齐曲线 if self.alignment_method == 'jump_point_alignment': aligned = self._align_by_jump_point(curve, self.target_length) else: # 默认使用线性插值 original_indices = np.linspace(0, len(curve) - 1, len(curve)) new_indices = np.linspace(0, len(curve) - 1, self.target_length) aligned = np.interp(new_indices, original_indices, curve) # 归一化 if self.scalers is not None: min_val = self.scalers[phase]['min'] max_val = self.scalers[phase]['max'] range_val = max_val - min_val if range_val == 0: normalized = np.zeros_like(aligned) else: normalized = (aligned - min_val) / range_val return normalized return aligned def _augment_batch(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """对批次数据进行增强""" augmentation_choice = np.random.rand() if augmentation_choice < 0.25: # 添加随机噪声 noise_level = np.random.uniform(0.01, 0.03) for phase in features: noise = torch.randn_like(features[phase]) * noise_level features[phase] = features[phase] + noise elif augmentation_choice < 0.5: # 时间滑动 shift_amount = int(np.random.uniform(-5, 5)) for phase in features: if shift_amount > 0: features[phase] = torch.cat([ torch.zeros(features[phase].size(0), shift_amount, device=features[phase].device), features[phase][:, :-shift_amount] ], dim=1) elif shift_amount < 0: features[phase] = torch.cat([ features[phase][:, -shift_amount:], torch.zeros(features[phase].size(0), -shift_amount, device=features[phase].device) ], dim=1) elif augmentation_choice < 0.75: # 幅度缩放 scale_factor = np.random.uniform(0.9, 1.1) for phase in features: features[phase] = features[phase] * scale_factor else: # 反转部分片段 if self.target_length > 10: start_idx = np.random.randint(0, self.target_length - 10) end_idx = np.random.randint(start_idx + 5, min(start_idx + 10, self.target_length)) for phase in features: segment = features[phase].clone()[:, start_idx:end_idx] features[phase][:, start_idx:end_idx] = torch.flip(segment, [1]) return features

定义数据集的collate_fn,用于批处理

def collate_fn(batch):
"""
自定义collate函数,处理不同长度的序列

text
参数: batch: 批次数据 返回: features_batch: 批次特征 labels_batch: 批次标签 """ # 分离特征和标签 features = [item[0] for item in batch] labels = [item[1] for item in batch] # 处理每个相的数据 features_batch = { 'A相': torch.stack([sample['A相'] for sample in features]), 'B相': torch.stack([sample['B相'] for sample in features]), 'C相': torch.stack([sample['C相'] for sample in features]), 'P相': torch.stack([sample['P相'] for sample in features]) } # 转换标签为张量 labels_batch = torch.LongTensor(labels) return features_batch, labels_batch

模型定义

class AttentionLayer(nn.Module):
"""
注意力层,用于关注序列中的重要部分
"""
def init(self, hidden_size: int):
"""
初始化注意力层

text
参数: hidden_size (int): 隐藏层大小 """ super(AttentionLayer, self).__init__() # 注意力计算所需的线性层 self.attn = nn.Linear(hidden_size * 2, hidden_size) self.v = nn.Parameter(torch.rand(hidden_size)) # 初始化参数 stdv = 1. / (hidden_size ** 0.5) self.v.data.uniform_(-stdv, stdv) def forward(self, hidden: torch.Tensor, encoder_outputs: torch.Tensor) -> torch.Tensor: """ 前向传播 参数: hidden (torch.Tensor): 隐藏状态 [batch_size, 1, hidden_size] encoder_outputs (torch.Tensor): 编码器输出 [batch_size, seq_len, hidden_size] 返回: torch.Tensor: 注意力加权后的上下文向量 [batch_size, 1, seq_len] """ batch_size = encoder_outputs.size(0) seq_len = encoder_outputs.size(1) # 将hidden复制seq_len次,使其与encoder_outputs的序列长度匹配 # 从 [batch_size, 1, hidden_size] 变为 [batch_size, seq_len, hidden_size] hidden_expanded = hidden.repeat(1, seq_len, 1) # 拼接hidden和encoder_outputs # [batch_size, seq_len, hidden_size*2] combined = torch.cat([hidden_expanded, encoder_outputs], dim=2) # 计算注意力能量 # [batch_size, seq_len, hidden_size] energy = torch.tanh(self.attn(combined)) # 计算注意力权重 # [batch_size, seq_len] attn_energies = torch.sum(self.v * energy, dim=2) # 归一化注意力权重 # [batch_size, 1, seq_len] return F.softmax(attn_energies, dim=1).unsqueeze(1) def score(self, hidden: torch.Tensor, encoder_outputs: torch.Tensor) -> torch.Tensor: """ 计算注意力得分 - 此方法已不再使用,保留为了兼容性 参数: hidden (torch.Tensor): 隐藏状态 encoder_outputs (torch.Tensor): 编码器输出 返回: torch.Tensor: 注意力得分 """ return self.forward(hidden, encoder_outputs).squeeze(1)

class ResidualBlock(nn.Module):
"""
残差块,用于构建深层CNN
"""
def init(self, in_channels: int, out_channels: int):
"""
初始化残差块

text
参数: in_channels (int): 输入通道数 out_channels (int): 输出通道数 """ super(ResidualBlock, self).__init__() # 第一个卷积层 self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1) self.bn1 = nn.BatchNorm1d(out_channels) # 第二个卷积层 self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1) self.bn2 = nn.BatchNorm1d(out_channels) # 如果输入和输出通道数不同,需要一个1x1卷积进行调整 self.shortcut = nn.Sequential() if in_channels != out_channels: self.shortcut = nn.Sequential( nn.Conv1d(in_channels, out_channels, kernel_size=1), nn.BatchNorm1d(out_channels) ) def forward(self, x: torch.Tensor) -> torch.Tensor: """ 前向传播 参数: x (torch.Tensor): 输入张量 返回: torch.Tensor: 残差块的输出 """ # 保存原始输入用于残差连接 residual = self.shortcut(x) # 卷积路径 out = F.relu(self.bn1(self.conv1(x))) out = self.bn2(self.conv2(out)) # 添加残差连接并应用ReLU out += residual out = F.relu(out) return out

class LSTMCNN(nn.Module):
"""
结合CNN和LSTM的模型,用于处理时序数据
"""
def init(self, input_channels: int = 4, seq_length: int = 150, num_classes: int = 9,
cnn_out_channels: int = 128, lstm_hidden_size: int = 128, lstm_layers: int = 2,
dropout: float = 0.5, bidirectional: bool = True):
"""
初始化LSTMCNN模型

text
参数: input_channels (int): 输入通道数,对应4个相 seq_length (int): 序列长度 num_classes (int): 分类类别数 cnn_out_channels (int): CNN输出通道数 lstm_hidden_size (int): LSTM隐藏层大小 lstm_layers (int): LSTM层数 dropout (float): Dropout比率 bidirectional (bool): LSTM是否双向 """ super(LSTMCNN, self).__init__() self.seq_length = seq_length self.hidden_size = lstm_hidden_size self.num_layers = lstm_layers self.bidirectional = bidirectional self.num_directions = 2 if bidirectional else 1 # CNN特征提取器 - 每个相单独处理 self.cnn_a = nn.Sequential( ResidualBlock(1, 32), nn.MaxPool1d(2), ResidualBlock(32, 64), nn.MaxPool1d(2), ResidualBlock(64, cnn_out_channels), nn.MaxPool1d(2) ) self.cnn_b = nn.Sequential( ResidualBlock(1, 32), nn.MaxPool1d(2), ResidualBlock(32, 64), nn.MaxPool1d(2), ResidualBlock(64, cnn_out_channels), nn.MaxPool1d(2) ) self.cnn_c = nn.Sequential( ResidualBlock(1, 32), nn.MaxPool1d(2), ResidualBlock(32, 64), nn.MaxPool1d(2), ResidualBlock(64, cnn_out_channels), nn.MaxPool1d(2) ) self.cnn_p = nn.Sequential( ResidualBlock(1, 32), nn.MaxPool1d(2), ResidualBlock(32, 64), nn.MaxPool1d(2), ResidualBlock(64, cnn_out_channels), nn.MaxPool1d(2) ) # 计算CNN后的序列长度 cnn_out_length = seq_length // 8 # 三次MaxPool1d(2)后长度减为原来的1/8 # LSTM层 self.lstm = nn.LSTM( input_size=cnn_out_channels * input_channels, # 4个相的CNN特征拼接 hidden_size=lstm_hidden_size, num_layers=lstm_layers, batch_first=True, dropout=dropout if lstm_layers > 1 else 0, bidirectional=bidirectional ) # 注意力层 self.attention = AttentionLayer(lstm_hidden_size * self.num_directions) # 分类器 self.classifier = nn.Sequential( nn.Linear(lstm_hidden_size * self.num_directions, 256), nn.ReLU(), nn.Dropout(dropout), nn.Linear(256, num_classes) ) def forward(self, x: Dict[str, torch.Tensor]) -> torch.Tensor: """ 前向传播 参数: x (Dict[str, torch.Tensor]): 输入特征字典,包含四个相的数据 返回: torch.Tensor: 模型输出(类别预测) """ batch_size = x['A相'].size(0) # 每个相的数据调整为CNN输入形状 [B, 1, L] x_a = x['A相'].unsqueeze(1) # [B, 1, L] x_b = x['B相'].unsqueeze(1) # [B, 1, L] x_c = x['C相'].unsqueeze(1) # [B, 1, L] x_p = x['P相'].unsqueeze(1) # [B, 1, L] # 通过CNN提取特征 cnn_out_a = self.cnn_a(x_a) # [B, C, L//8] cnn_out_b = self.cnn_b(x_b) # [B, C, L//8] cnn_out_c = self.cnn_c(x_c) # [B, C, L//8] cnn_out_p = self.cnn_p(x_p) # [B, C, L//8] # 调整形状以便于拼接,[B, C, L//8] -> [B, L//8, C] cnn_out_a = cnn_out_a.transpose(1, 2) cnn_out_b = cnn_out_b.transpose(1, 2) cnn_out_c = cnn_out_c.transpose(1, 2) cnn_out_p = cnn_out_p.transpose(1, 2) # 拼接四个相的特征,在特征维度上合并 cnn_out_combined = torch.cat([cnn_out_a, cnn_out_b, cnn_out_c, cnn_out_p], dim=2) # 通过LSTM处理时序特征 lstm_out, (hidden, _) = self.lstm(cnn_out_combined) # 使用注意力机制处理LSTM输出 if self.bidirectional: # 双向LSTM,需要拼接最后一层的前向和后向 hidden = torch.cat([hidden[-2], hidden[-1]], dim=1).unsqueeze(1) else: # 单向LSTM,只使用最后一层 hidden = hidden[-1].unsqueeze(1) attn_weights = self.attention(hidden, lstm_out) context = torch.bmm(attn_weights, lstm_out) context = context.squeeze(1) # 全连接分类 output = self.classifier(context) return output

训练和评估函数

def train_model(model: nn.Module,
train_loader: DataLoader,
val_loader: DataLoader,
num_epochs: int = 100,
learning_rate: float = 0.001,
weight_decay: float = 0.01,
patience: int = 10,
checkpoint_path: str = 'best_model.pth',
class_weights: Optional[torch.Tensor] = None,
device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')) -> Dict[str, List]:
"""
训练模型

text
参数: model (nn.Module): 要训练的模型 train_loader (DataLoader): 训练数据加载器 val_loader (DataLoader): 验证数据加载器 num_epochs (int): 训练轮数 learning_rate (float): 学习率 weight_decay (float): 权重衰减系数 patience (int): 早停耐心值 checkpoint_path (str): 模型检查点保存路径 class_weights (torch.Tensor, 可选): 类别权重 device (torch.device): 训练设备 返回: Dict[str, List]: 包含训练历史的字典 """ logger.info(f"开始训练,设备: {device}") # 如果有多GPU,使用DataParallel if torch.cuda.device_count() > 1 and isinstance(device, torch.device) and device.type == 'cuda': logger.info(f"使用 {torch.cuda.device_count()} 个GPU进行训练") model = nn.DataParallel(model) model = model.to(device) # 尝试加载检查点 start_epoch = 0 best_val_f1 = 0.0 early_stop_counter = 0 if os.path.exists(checkpoint_path): try: logger.info(f"从 {checkpoint_path} 加载检查点...") checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) if isinstance(model, nn.DataParallel): model.module.load_state_dict(checkpoint['model_state_dict']) else: model.load_state_dict(checkpoint['model_state_dict']) optimizer_state_dict = checkpoint.get('optimizer_state_dict', None) scheduler_state_dict = checkpoint.get('scheduler_state_dict', None) start_epoch = checkpoint.get('epoch', -1) + 1 best_val_f1 = checkpoint.get('best_val_f1', 0.0) early_stop_counter = checkpoint.get('early_stop_counter', 0) logger.info(f"检查点加载成功。从 Epoch {start_epoch+1} 继续训练。") logger.info(f"恢复的最佳验证F1: {best_val_f1:.4f}, Early stop counter: {early_stop_counter}") except Exception as e: logger.error(f"加载检查点失败: {e}。将从头开始训练。") start_epoch = 0 best_val_f1 = 0.0 early_stop_counter = 0 else: logger.info("未找到检查点,将从头开始训练。") # 定义损失函数和优化器 if class_weights is not None: class_weights = class_weights.to(device) criterion = nn.CrossEntropyLoss(weight=class_weights) logger.info("使用带权重的交叉熵损失") else: criterion = nn.CrossEntropyLoss() logger.info("使用标准交叉熵损失") optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5) # 如果加载了检查点,恢复优化器和调度器状态 if 'optimizer_state_dict' in locals() and optimizer_state_dict is not None: optimizer.load_state_dict(optimizer_state_dict) logger.info("优化器状态已恢复") if 'scheduler_state_dict' in locals() and scheduler_state_dict is not None: scheduler.load_state_dict(scheduler_state_dict) logger.info("学习率调度器状态已恢复") # 训练历史 history = { 'train_loss': [], 'train_acc': [], 'train_f1': [], 'val_loss': [], 'val_acc': [], 'val_f1': [] } # 早停机制 # best_val_f1 已经从检查点恢复或初始化为0.0 # early_stop_counter 已经从检查点恢复或初始化为0 # 开始训练循环 for epoch in range(start_epoch, num_epochs): model.train() train_loss = 0.0 train_correct = 0 train_total = 0 train_pred_all = [] train_true_all = [] batch_count = 0 # 计数处理的批次数 # 训练阶段 progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]") for batch_idx, (features, targets) in enumerate(progress_bar): batch_count += 1 # 更新批次计数 # 将数据移动到设备 for phase in features: features[phase] = features[phase].to(device) targets = targets.to(device) # 清除梯度 optimizer.zero_grad() # 前向传播 outputs = model(features) loss = criterion(outputs, targets) # 反向传播和优化 loss.backward() optimizer.step() # 统计 train_loss += loss.item() _, predicted = outputs.max(1) train_correct += predicted.eq(targets).sum().item() train_total += targets.size(0) # 收集预测和真实标签用于计算F1 train_pred_all.extend(predicted.cpu().numpy()) train_true_all.extend(targets.cpu().numpy()) # 更新进度条 - 使用当前批次平均值 current_train_loss_avg = train_loss / batch_count if batch_count > 0 else 0.0 current_train_acc_avg = train_correct / train_total if train_total > 0 else 0.0 progress_bar.set_postfix(loss=current_train_loss_avg, acc=current_train_acc_avg) # 计算训练F1分数 train_f1 = f1_score(train_true_all, train_pred_all, average='macro', zero_division=0) # 验证阶段 model.eval() val_loss = 0.0 val_correct = 0 val_total = 0 val_pred_all = [] val_true_all = [] val_batch_count = 0 # 验证集批次计数 with torch.no_grad(): progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Validation]") for batch_idx, (features, targets) in enumerate(progress_bar): val_batch_count += 1 # 更新验证集批次计数 # 将数据移动到设备 for phase in features: features[phase] = features[phase].to(device) targets = targets.to(device) # 前向传播 outputs = model(features) loss = criterion(outputs, targets) # 统计 val_loss += loss.item() _, predicted = outputs.max(1) val_correct += predicted.eq(targets).sum().item() val_total += targets.size(0) # 收集预测和真实标签 val_pred_all.extend(predicted.cpu().numpy()) val_true_all.extend(targets.cpu().numpy()) # 更新进度条 - 使用当前批次平均值 current_val_loss_avg = val_loss / val_batch_count if val_batch_count > 0 else 0.0 current_val_acc_avg = val_correct / val_total if val_total > 0 else 0.0 progress_bar.set_postfix(loss=current_val_loss_avg, acc=current_val_acc_avg) # 计算验证F1分数 val_f1 = f1_score(val_true_all, val_pred_all, average='macro', zero_division=0) # 更新学习率 scheduler.step(val_f1) # 保存历史 - 使用批次计数而不是len(dataloader) # 添加除零保护 avg_train_loss = train_loss / batch_count if batch_count > 0 else 0.0 avg_val_loss = val_loss / val_batch_count if val_batch_count > 0 else 0.0 avg_train_acc = train_correct / train_total if train_total > 0 else 0.0 avg_val_acc = val_correct / val_total if val_total > 0 else 0.0 if val_batch_count == 0: logger.warning(f"Epoch {epoch+1}: 验证集为空或未产生批次,验证损失和准确率将为0。") history['train_loss'].append(avg_train_loss) history['train_acc'].append(avg_train_acc) history['train_f1'].append(train_f1) history['val_loss'].append(avg_val_loss) history['val_acc'].append(avg_val_acc) history['val_f1'].append(val_f1) # 计算每个类别的F1分数 train_f1_per_class = f1_score(train_true_all, train_pred_all, average=None, zero_division=0) val_f1_per_class = f1_score(val_true_all, val_pred_all, average=None, zero_division=0) # 输出本轮结果 logger.info(f"Epoch {epoch+1}/{num_epochs} - " f"Train Loss: {avg_train_loss:.4f}, Train Acc: {avg_train_acc:.4f}, Train F1: {train_f1:.4f}, " f"Val Loss: {avg_val_loss:.4f}, Val Acc: {avg_val_acc:.4f}, Val F1: {val_f1:.4f}") # 输出每个类别的F1分数 logger.info(f"Train F1 per class: {', '.join([f'{f:.4f}' for f in train_f1_per_class])}") logger.info(f"Val F1 per class: {', '.join([f'{f:.4f}' for f in val_f1_per_class])}") # 检查是否是最佳模型 if val_f1 > best_val_f1: best_val_f1 = val_f1 # 保存模型和训练状态 checkpoint = { 'epoch': epoch, 'model_state_dict': model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), 'best_val_f1': best_val_f1, 'early_stop_counter': early_stop_counter, 'class_weights': class_weights, } torch.save(checkpoint, checkpoint_path) logger.info(f"检查点已保存到 {checkpoint_path},Epoch: {epoch+1}, 验证F1: {val_f1:.4f}") early_stop_counter = 0 else: early_stop_counter += 1 logger.info(f"验证F1未提升,当前计数: {early_stop_counter}/{patience}") # 早停检查 if early_stop_counter >= patience: logger.info(f"早停触发,停止训练。最佳验证F1: {best_val_f1:.4f}") break return history

def evaluate_model(model: nn.Module,
data_loader: DataLoader,
device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
output_dir: Optional[str] = None) -> Dict[str, Union[float, np.ndarray]]:
"""
评估模型

text
参数: model (nn.Module): 要评估的模型 data_loader (DataLoader): 数据加载器 device (torch.device): 计算设备 output_dir (str, 可选): 输出目录,如果提供,会保存混淆矩阵图 返回: Dict[str, Union[float, np.ndarray]]: 评估结果 """ logger.info("开始评估模型") model.eval() criterion = nn.CrossEntropyLoss() total_loss = 0.0 all_preds = [] all_targets = [] batch_count = 0 with torch.no_grad(): for features, targets in tqdm(data_loader, desc="评估"): batch_count += 1 # 将数据移动到设备 for phase in features: features[phase] = features[phase].to(device) targets = targets.to(device) # 前向传播 outputs = model(features) loss = criterion(outputs, targets) # 统计 total_loss += loss.item() _, predicted = outputs.max(1) # 收集预测和真实标签 all_preds.extend(predicted.cpu().numpy()) all_targets.extend(targets.cpu().numpy()) # 计算平均损失 - 使用批次计数 avg_loss = total_loss / batch_count if batch_count > 0 else 0.0 if batch_count == 0: logger.warning("评估时数据加载器为空或未产生批次。") # 计算准确率 accuracy = sum(np.array(all_preds) == np.array(all_targets)) / len(all_targets) if len(all_targets) > 0 else 0.0 # 计算F1分数 f1 = f1_score(all_targets, all_preds, average='macro', zero_division=0) # 计算混淆矩阵 cm = confusion_matrix(all_targets, all_preds) # 如果提供了输出目录,保存混淆矩阵图 if output_dir: if not os.path.exists(output_dir): os.makedirs(output_dir) plt.figure(figsize=(10, 8)) sns.heatmap(cm, annot=True, fmt='d', cmap='Blues') plt.xlabel('预测标签') plt.ylabel('真实标签') plt.title('混淆矩阵') plt.savefig(os.path.join(output_dir, 'confusion_matrix.png')) plt.close() # 保存分类报告 report_dict = classification_report(all_targets, all_preds, output_dict=True, zero_division=0) with open(os.path.join(output_dir, 'classification_report.json'), 'w') as f: json.dump(report_dict, f, indent=4) report_str = classification_report(all_targets, all_preds, zero_division=0) logger.info(f"分类报告:\n{report_str}") logger.info(f"评估结果 - 损失: {avg_loss:.4f}, 准确率: {accuracy:.4f}, F1分数: {f1:.4f}") return { 'loss': avg_loss, 'accuracy': accuracy, 'f1': f1, 'confusion_matrix': cm, 'predictions': np.array(all_preds), 'targets': np.array(all_targets) }

def plot_training_history(history: Dict[str, List], save_path: Optional[str] = None):
"""
绘制训练历史

text
参数: history (Dict[str, List]): 训练历史 save_path (str, 可选): 保存路径,不提供则显示图像 """ plt.figure(figsize=(15, 10)) # 绘制损失 plt.subplot(2, 2, 1) plt.plot(history['train_loss'], label='训练损失') plt.plot(history['val_loss'], label='验证损失') plt.title('损失') plt.xlabel('轮次') plt.ylabel('损失') plt.legend() plt.grid(True) # 绘制准确率 plt.subplot(2, 2, 2) plt.plot(history['train_acc'], label='训练准确率') plt.plot(history['val_acc'], label='验证准确率') plt.title('准确率') plt.xlabel('轮次') plt.ylabel('准确率') plt.legend() plt.grid(True) # 绘制F1分数 plt.subplot(2, 2, 3) plt.plot(history['train_f1'], label='训练F1') plt.plot(history['val_f1'], label='验证F1') plt.title('F1分数') plt.xlabel('轮次') plt.ylabel('F1') plt.legend() plt.grid(True) plt.tight_layout() if save_path: plt.savefig(save_path) plt.close() else: plt.show()

模型预测函数

def predict_sample(model: nn.Module,
a_phase: str,
b_phase: str,
c_phase: str,
p_phase: str,
target_length: int = 150,
alignment_method: str = 'linear_interpolation',
scalers: Dict[str, Dict[str, float]] = None,
device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')) -> Dict[str, Union[int, float]]:
"""
预测单个样本

text
参数: model (nn.Module): 模型 a_phase (str): A相曲线字符串 b_phase (str): B相曲线字符串 c_phase (str): C相曲线字符串 p_phase (str): P相曲线字符串 target_length (int): 目标序列长度 alignment_method (str): 对齐方法 scalers (Dict[str, Dict[str, float]]): 归一化参数 device (torch.device): 计算设备 返回: Dict[str, Union[int, float]]: 预测结果 """ model.eval() # 解析曲线 def parse_curve(curve_str): return [float(val) for val in curve_str.strip('|').split('|')] # 处理曲线 def align_curve(curve, target_length, method): if len(curve) == target_length: return np.array(curve) if method == 'linear_interpolation': original_indices = np.linspace(0, len(curve) - 1, len(curve)) new_indices = np.linspace(0, len(curve) - 1, target_length) return np.interp(new_indices, original_indices, curve) elif method == 'jump_point_alignment': return align_by_jump_point_local(curve, target_length) # 使用局部函数避免命名冲突 else: # 默认使用线性插值 original_indices = np.linspace(0, len(curve) - 1, len(curve)) new_indices = np.linspace(0, len(curve) - 1, target_length) return np.interp(new_indices, original_indices, curve) # 跳变点对齐 (局部版本) def align_by_jump_point_local(curve, target_length): curve_array = np.array(curve) jump_idx = find_jump_point_local(curve_array) # 使用局部函数 target_jump_idx = int(target_length * 0.3) left_pad = max(0, target_jump_idx - jump_idx) right_pad = max(0, (target_length - left_pad - len(curve))) if len(curve) + left_pad + right_pad > target_length: trim_left = max(0, jump_idx - target_jump_idx) trim_right = max(0, len(curve) - trim_left - (target_length - left_pad - right_pad)) trimmed_curve = curve_array[trim_left:len(curve)-trim_right] else: trimmed_curve = curve_array padded = np.zeros(target_length) padded[left_pad:left_pad+len(trimmed_curve)] = trimmed_curve return padded # 寻找跳变点 (局部版本) def find_jump_point_local(curve, threshold=0.5): if len(curve) <= 1: # 处理空或单点曲线 return len(curve) // 2 diffs = np.abs(np.diff(curve)) if len(diffs) == 0: # 如果差分后为空 (例如原始曲线只有1个点) return len(curve) // 2 mean_diff = np.mean(diffs) std_diff = np.std(diffs) jump_threshold = mean_diff + threshold * std_diff jump_indices = np.where(diffs > jump_threshold)[0] if len(jump_indices) > 0: return jump_indices[np.argmax(diffs[jump_indices])] + 1 else: return len(curve) // 2 # 归一化曲线 def normalize_curve(curve, phase, scalers): if scalers is None: return curve min_val = scalers[phase]['min'] max_val = scalers[phase]['max'] range_val = max_val - min_val if range_val == 0: return np.zeros_like(curve) return (curve - min_val) / range_val # 处理每个相的曲线 curves_data = { 'A相': parse_curve(a_phase), 'B相': parse_curve(b_phase), 'C相': parse_curve(c_phase), 'P相': parse_curve(p_phase) } # 对齐曲线 aligned_curves_data = { phase: align_curve(curves_data[phase], target_length, alignment_method) for phase in curves_data } # 归一化曲线 if scalers is not None: normalized_curves_data = { phase: normalize_curve(aligned_curves_data[phase], phase, scalers) for phase in aligned_curves_data } else: normalized_curves_data = aligned_curves_data # 转换为张量 features_tensor = { phase: torch.FloatTensor(normalized_curves_data[phase]).unsqueeze(0).to(device) for phase in normalized_curves_data } # 预测 with torch.no_grad(): outputs = model(features_tensor) probabilities = F.softmax(outputs, dim=1) predicted_class = outputs.argmax(dim=1).item() confidence_score = probabilities[0, predicted_class].item() return { 'prediction': predicted_class, 'confidence': confidence_score, 'probabilities': probabilities[0].cpu().numpy().tolist() }

生成样本曲线的图像

def generate_curve_plot(a_phase: str, b_phase: str, c_phase: str, p_phase: str) -> str:
"""
生成曲线图像并返回Base64编码的图像

text
参数: a_phase (str): A相曲线字符串 b_phase (str): B相曲线字符串 c_phase (str): C相曲线字符串 p_phase (str): P相曲线字符串 返回: str: Base64编码的图像 """ # 解析曲线 def parse_curve(curve_str): return [float(val) for val in curve_str.strip('|').split('|')] curves = { 'A相': parse_curve(a_phase), 'B相': parse_curve(b_phase), 'C相': parse_curve(c_phase), 'P相': parse_curve(p_phase) } plt.figure(figsize=(12, 8)) # 创建一个2x2的子图布局 phases_list = ['A相', 'B相', 'C相', 'P相'] for i, phase_name in enumerate(phases_list): plt.subplot(2, 2, i+1) plt.plot(curves[phase_name]) plt.title(f"{phase_name} 曲线") plt.xlabel("时间步") plt.ylabel("值") plt.grid(True) plt.tight_layout() # 将图像保存到内存中的BytesIO对象 buffer = io.BytesIO() plt.savefig(buffer, format='png') plt.close() # 获取图像的Base64编码 buffer.seek(0) img_str = base64.b64encode(buffer.read()).decode('utf-8') return img_str

FastAPI应用

app = FastAPI(title="故障曲线分类预测")
templates = Jinja2Templates(directory="templates")

class InputData(BaseModel):
"""API输入数据模型"""
a_phase: str
b_phase: str
c_phase: str
p_phase: str

全局模型和归一化器

global_model = None
global_scalers = None

def load_model_for_api(model_path: str = 'best_model.pth', num_classes: int = 9):
"""
加载已训练的模型 (API专用,避免与主训练流程中的model变量混淆)

text
参数: model_path (str): 模型文件路径 num_classes (int): 类别数 返回: nn.Module: 加载的模型 """ global global_model if global_model is not None: return global_model try: loaded_model = LSTMCNN(num_classes=num_classes) try: checkpoint = torch.load(model_path, map_location=device, weights_only=False) except FileNotFoundError: logger.error(f"模型文件未找到 (API): {model_path}") raise except Exception as e: logger.error(f"加载检查点时出错 (API): {str(e)}") raise # 加载模型权重 if 'model_state_dict' in checkpoint: loaded_model.load_state_dict(checkpoint['model_state_dict']) else: loaded_model.load_state_dict(checkpoint) logger.info(f"从 {model_path} 加载模型成功 (API), 类别数: {num_classes}") except Exception as e: logger.error(f"加载模型失败 (API): {str(e)}") raise loaded_model = loaded_model.to(device) loaded_model.eval() global_model = loaded_model # 赋值给全局变量 return global_model

def load_scalers_for_api(scalers_path: str = 'scalers.pkl'):
"""
加载归一化器 (API专用)

text
参数: scalers_path (str): 归一化器文件路径 返回: Dict[str, Dict[str, float]]: 加载的归一化参数 """ global global_scalers # 修改这里以匹配全局变量名 if global_scalers is not None: return global_scalers try: import pickle with open(scalers_path, 'rb') as f: loaded_scalers = pickle.load(f) # 使用不同的局部变量名 logger.info(f"从 {scalers_path} 加载归一化器成功 (API)") global_scalers = loaded_scalers # 赋值给全局变量 except FileNotFoundError: logger.error(f"归一化器文件未找到: {scalers_path} (API). 将不使用归一化器。") global_scalers = None # 明确设置为None except Exception as e: logger.error(f"加载归一化器失败 (API): {str(e)}") global_scalers = None # 明确设置为None return global_scalers

@app.get("/", response_class=HTMLResponse)
async def get_index(request: Request):
"""
返回主页
"""
return templates.TemplateResponse("index.html", {"request": request})

@app.get("/index2", response_class=HTMLResponse)
async def get_index2(request: Request):
"""
返回增强版主页
"""
return templates.TemplateResponse("index2.html", {"request": request})

@app.post("/predict")
async def predict_api_endpoint(data: InputData): # 重命名函数以避免与内部predict函数混淆
"""
预测API端点

text
参数: data (InputData): 输入数据 返回: Dict: 预测结果 """ try: # 加载模型和归一化器 # 假设模型和scalers在启动时或第一次调用时加载 # 如果尚未加载,则加载它们 # 注意:这里的num_classes需要与训练时一致,可以考虑从配置文件读取或从模型检查点获取 # 暂时硬编码为9,如果您的类别数不同,请修改 current_model = load_model_for_api(num_classes=9) # 确保这里的num_classes正确 current_scalers = load_scalers_for_api() if current_model is None: raise HTTPException(status_code=500, detail="模型未能加载") # 预测 result = predict_sample( model=current_model, a_phase=data.a_phase, b_phase=data.b_phase, c_phase=data.c_phase, p_phase=data.p_phase, scalers=current_scalers, # 传递加载的scalers device=device ) # 生成图像 plot_img = generate_curve_plot(data.a_phase, data.b_phase, data.c_phase, data.p_phase) return { "success": True, "prediction": result["prediction"], "confidence": result["confidence"], "probabilities": result["probabilities"], "plot_image": plot_img } except HTTPException as http_exc: # 重新抛出HTTPException raise http_exc except Exception as e: logger.exception("预测时出错 (API)") raise HTTPException(status_code=500, detail=f"内部服务器错误: {str(e)}")

def get_data_splits_indices(file_path: str,
train_ratio: float,
val_ratio: float,
random_state: int = 42) -> Tuple[List[int], List[int], List[int]]:
"""
计算并返回用于训练、验证和测试集的打乱后的行索引。
索引是相对于数据文件的行号(0-based, 不包括表头)。
"""
logger.info("正在为流式数据集确定行索引...")
try:
with open(file_path, 'r', encoding='utf-8') as f:
# 计算总行数 (跳过表头)
total_lines = sum(1 for _ in f) - 1
if total_lines <= 0:
raise ValueError("数据文件为空或只有表头。")
except FileNotFoundError:
logger.error(f"数据文件未找到: {file_path}")
raise
except Exception as e:
logger.error(f"读取文件行数时出错: {e}")
raise

text
all_indices = list(range(total_lines)) # 使用numpy进行可复现的打乱 rng = np.random.RandomState(random_state) rng.shuffle(all_indices) # 计算每个集合的样本数 train_count = int(total_lines * train_ratio) val_count = int(total_lines * val_ratio) train_indices = all_indices[:train_count] val_indices = all_indices[train_count : train_count + val_count] test_indices = all_indices[train_count + val_count:] logger.info(f"总数据行数: {total_lines}") logger.info(f"训练集索引数: {len(train_indices)}, 验证集索引数: {len(val_indices)}, 测试集索引数: {len(test_indices)}") # 确保没有重叠且覆盖所有索引 assert len(set(train_indices) | set(val_indices) | set(test_indices)) == total_lines assert len(set(train_indices) & set(val_indices)) == 0 assert len(set(train_indices) & set(test_indices)) == 0 assert len(set(val_indices) & set(test_indices)) == 0 return train_indices, val_indices, test_indices

主函数

def main():
"""主函数,用于训练和评估模型"""
try:
# 设置输出目录
output_dir = "model_output"
os.makedirs(output_dir, exist_ok=True)

text
# 定义数据集文件路径 data_path = 'curve_data_pivot_new.csv' if not os.path.exists(data_path): logger.error(f"数据文件 {data_path} 未找到。请确保文件路径正确。") return # 首先进行一次文件扫描,获取基本统计信息和类别数量 logger.info("扫描数据集获取基本信息...") # 使用纯CSV读取而不是pandas,更节省内存 class_counts = {} curve_lengths = {'A相': [], 'B相': [], 'C相': [], 'P相': []} curve_mins = {'A相': float('inf'), 'B相': float('inf'), 'C相': float('inf'), 'P相': float('inf')} curve_maxs = {'A相': float('-inf'), 'B相': float('-inf'), 'C相': float('-inf'), 'P相': float('-inf')} # 仅采样部分数据计算统计值 total_rows = 0 sample_interval = 100 # 每100行采样一次 with open(data_path, 'r', encoding='utf-8') as f: reader = csv.reader(f) header = next(reader) # 跳过表头 for i, row in enumerate(tqdm(reader, desc="扫描数据集")): total_rows += 1 # 统计类别分布 try: label = int(row[3]) if label in class_counts: class_counts[label] += 1 else: class_counts[label] = 1 except (IndexError, ValueError) as e: logger.warning(f"解析标签时跳过行 {i+1}: {e} - Row: {row}") continue # 采样部分数据计算统计值 if i % sample_interval == 0: for j, phase in enumerate(['A相', 'B相', 'C相', 'P相']): try: curve_str = row[j+4] values = [float(val) for val in curve_str.strip('|').split('|') if val] # 确保val不为空 if not values: continue # 如果解析后为空列表,跳过 curve_lengths[phase].append(len(values)) # 更新最小值和最大值 min_val = min(values) max_val = max(values) curve_mins[phase] = min(curve_mins[phase], min_val) curve_maxs[phase] = max(curve_maxs[phase], max_val) except (IndexError, ValueError) as e: logger.warning(f"解析曲线数据时跳过行 {i+1}, 相 {phase}: {e} - Row: {row}") continue if total_rows == 0: logger.error("数据文件为空或未能读取任何行。") return if not class_counts: logger.error("未能从数据文件中解析任何类别标签。请检查数据格式。") return # 计算每个相的长度统计和归一化参数 scalers = {} for phase in ['A相', 'B相', 'C相', 'P相']: # 确保在max/min仍为inf/-inf时(即没有有效数据点)提供默认值 scalers[phase] = { 'min': curve_mins[phase] if curve_mins[phase] != float('inf') else 0.0, 'max': curve_maxs[phase] if curve_maxs[phase] != float('-inf') else 1.0 } # 如果min和max相同,调整max以避免除零 if scalers[phase]['min'] == scalers[phase]['max']: scalers[phase]['max'] += 1e-6 # 添加一个小值 # 使用计算得到的统计信息 logger.info(f"数据集总行数: {total_rows}") logger.info(f"类别分布: {class_counts}") # 确定 num_classes 基于实际观察到的最大类别索引 if not class_counts: logger.error("无法确定类别数量,因为类别分布为空。") return # 或者设置一个默认值,但这可能不正确 num_classes = max(class_counts.keys()) + 1 logger.info(f"根据数据推断的类别数量 (max_label + 1): {num_classes}") # 保存归一化器 import pickle scalers_path = os.path.join(output_dir, 'scalers.pkl') with open(scalers_path, 'wb') as f: pickle.dump(scalers, f) logger.info(f"归一化器已保存到 {scalers_path}") # 计算类别权重 class_weights_tensor = torch.zeros(num_classes) # 使用推断的num_classes # 确保所有类别(从0到num_classes-1)都有一个权重,即使它们在数据中不存在 for class_idx in range(num_classes): class_count = class_counts.get(class_idx, 0) # 获取计数,如果不存在则为0 if class_count > 0: # 使用实际观察到的类别数量 len(class_counts) 进行加权 class_weights_tensor[class_idx] = total_rows / (len(class_counts) * class_count) else: # 对于数据中未出现的类别,可以设置权重为0或1,或基于其他策略 # 设置为0意味着损失函数中不会考虑这些类别(如果它们意外出现) # 设置为1意味着它们具有“平均”权重(如果它们是可能的但稀有的类别) class_weights_tensor[class_idx] = 1.0 # 或0.0,取决于策略 logger.warning(f"类别 {class_idx} 在数据中未找到,权重设置为1.0。") logger.info(f"类别权重: {class_weights_tensor}") # 获取打乱后的数据划分索引 train_indices, val_indices, test_indices = get_data_splits_indices( file_path=data_path, train_ratio=0.7, val_ratio=0.15, random_state=42 ) # 创建流式训练集 train_dataset = StreamingFaultDataset( file_path=data_path, assigned_indices_list=train_indices, batch_size=32, target_length=150, alignment_method='jump_point_alignment', is_train=True, scalers=scalers, buffer_size=1000 ) # 创建流式验证集 val_dataset = StreamingFaultDataset( file_path=data_path, assigned_indices_list=val_indices, batch_size=32, target_length=150, alignment_method='jump_point_alignment', is_train=False, scalers=scalers, buffer_size=1000 ) # 创建流式测试集 test_dataset = StreamingFaultDataset( file_path=data_path, assigned_indices_list=test_indices, batch_size=32, target_length=150, alignment_method='jump_point_alignment', is_train=False, scalers=scalers, buffer_size=1000 ) # 创建数据加载器 # DataLoader的batch_size=None表示数据集本身产生批次 train_loader = DataLoader( train_dataset, batch_size=None, num_workers=0 # 多进程在Windows上可能需要特殊处理,暂时设为0 ) val_loader = DataLoader( val_dataset, batch_size=None, num_workers=0 ) test_loader = DataLoader( test_dataset, batch_size=None, num_workers=0 ) # 创建模型 model_instance = LSTMCNN( # 使用不同的变量名 input_channels=4, seq_length=150, num_classes=num_classes, cnn_out_channels=128, lstm_hidden_size=set_lstm_hidden_size, lstm_layers=set_lstm_layers, dropout=set_dropout, bidirectional=True ) # 打印模型结构 logger.info(f"模型结构: {model_instance}") # 训练模型 logger.info("开始训练模型...") checkpoint_path = os.path.join(output_dir, 'best_model.pth') history = train_model( model=model_instance, train_loader=train_loader, val_loader=val_loader, num_epochs=100, learning_rate=0.001, weight_decay=0.01, patience=20, checkpoint_path=checkpoint_path, class_weights=class_weights_tensor, device=device ) # 绘制训练历史 history_path = os.path.join(output_dir, 'training_history.png') plot_training_history(history, save_path=history_path) logger.info(f"训练历史已保存到 {history_path}") # 加载最佳模型进行评估 best_model_for_eval = LSTMCNN( # 使用不同的变量名 input_channels=4, seq_length=150, num_classes=num_classes, cnn_out_channels=128, lstm_hidden_size=set_lstm_hidden_size, lstm_layers=set_lstm_layers, dropout=set_dropout, bidirectional=True ) if not os.path.exists(checkpoint_path): logger.error(f"最佳模型检查点 {checkpoint_path} 未找到。跳过评估。") else: checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) # 兼容旧的只保存模型权重的格式和新的包含 'model_state_dict' 的格式 if 'model_state_dict' in checkpoint: best_model_for_eval.load_state_dict(checkpoint['model_state_dict']) else: best_model_for_eval.load_state_dict(checkpoint) logger.info(f"已从 {checkpoint_path} 加载最佳模型权重进行评估。") best_model_for_eval = best_model_for_eval.to(device) # 评估模型 logger.info("评估模型在测试集上的表现...") evaluation_dir = os.path.join(output_dir, 'evaluation') os.makedirs(evaluation_dir, exist_ok=True) test_results = evaluate_model( model=best_model_for_eval, data_loader=test_loader, device=device, output_dir=evaluation_dir ) logger.info(f"测试结果 - 损失: {test_results['loss']:.4f}, 准确率: {test_results['accuracy']:.4f}, F1分数: {test_results['f1']:.4f}") # 保存评估结果 evaluation_results_data = { # 使用不同的变量名 'loss': float(test_results['loss']), 'accuracy': float(test_results['accuracy']), 'f1': float(test_results['f1']), 'confusion_matrix': test_results['confusion_matrix'].tolist() } with open(os.path.join(evaluation_dir, 'evaluation_results.json'), 'w') as f: json.dump(evaluation_results_data, f, indent=4) logger.info("模型训练和评估完成!") except Exception as e: logger.exception("训练过程中出错") # raise # 在生产环境中可能不想重新抛出,而是记录并优雅退出

API服务器启动函数

def run_api_server(host="0.0.0.0", port=8000, num_classes_api=9): # 添加num_classes参数
"""
启动API服务器

text
参数: host (str): 主机 port (int): 端口 num_classes_api (int): API模型加载时使用的类别数 """ # 在服务器启动时预加载模型和scalers logger.info("正在为API预加载模型和归一化器...") load_model_for_api(model_path=os.path.join("model_output", 'best_model.pth')) load_scalers_for_api(scalers_path=os.path.join("model_output", 'scalers.pkl')) logger.info("模型和归一化器预加载完成 (API)。") import uvicorn uvicorn.run(app, host=host, port=port)

def configure_chinese_font():
"""配置中文字体,根据系统可用字体选择合适的中文字体"""
import matplotlib.font_manager as fm
import platform

text
# 获取所有可用的字体 font_names = [f.name for f in fm.fontManager.ttflist] # 按优先级排序的中文字体列表 chinese_fonts = [ 'Noto Sans CJK SC', # 思源黑体简体中文 'Noto Sans CJK JP', # 思源黑体日文(包含中文) 'Noto Sans CJK TC', # 思源黑体繁体中文 'WenQuanYi Micro Hei', # 文泉驿微米黑 'Microsoft YaHei', # 微软雅黑 'SimHei', # 黑体 'SimSun', # 宋体 'Arial Unicode MS', # Arial Unicode ] # 检查系统中可用的中文字体 available_chinese_fonts = [font for font in chinese_fonts if font in font_names] if available_chinese_fonts: # 使用找到的第一个可用中文字体 plt.rcParams['font.sans-serif'] = available_chinese_fonts print(f"使用中文字体: {available_chinese_fonts[0]}") else: # 如果没有找到任何预定义的中文字体,使用系统默认字体 print("警告: 未找到预设的中文字体,可能无法正确显示中文") plt.rcParams['axes.unicode_minus'] = False plt.rcParams['font.family'] = 'sans-serif' # 返回使用的字体,便于调试 return available_chinese_fonts[0] if available_chinese_fonts else None

如果直接运行此脚本

if name == "main":
configure_chinese_font()
# 在Windows上确保多进程支持
if os.name == 'nt': # 更明确地检查操作系统
multiprocessing.freeze_support()

text
import argparse parser = argparse.ArgumentParser(description="故障曲线分类预测系统") parser.add_argument("--mode", type=str, choices=["train", "serve"], default="train", help="运行模式: train (训练模型) 或 serve (启动API服务)") parser.add_argument("--port", type=int, default=8000, help="API服务端口") parser.add_argument("--host", type=str, default="0.0.0.0", help="API服务主机") # 添加一个参数来传递类别数给API服务,因为这可能与训练时动态确定的不同 # 或者,更好的做法是从保存的检查点或配置文件中读取 args = parser.parse_args() if args.mode == "train": main() elif args.mode == "serve": # 确保输出目录存在,以便API可以加载模型/scalers output_dir = "model_output" if not os.path.exists(os.path.join(output_dir, 'best_model.pth')): logger.warning(f"API模式:模型文件 {os.path.join(output_dir, 'best_model.pth')} 未找到。请先训练模型。") # 可以选择退出或尝试继续(如果API逻辑能处理模型不存在的情况) if not os.path.exists(os.path.join(output_dir, 'scalers.pkl')): logger.warning(f"API模式:归一化文件 {os.path.join(output_dir, 'scalers.pkl')} 未找到。预测可能不准确。") run_api_server(host=args.host, port=args.port)

目前有什么问题

思考

Thought for 8 seconds

回答

下面把我在快速审阅中发现的主要问题与改进建议分为 5 个维度,先说“硬 Bug / 隐患”,再说“结构与可维护性”,最后到“效率、风格与部署”。若有任何点想深挖,告诉我咱们再展开。


1 ️⃣ 运行期硬伤 & 潜在崩溃点

代码位置问题描述影响 & 修复思路
文件内容整体整个脚本从 # 导入所需的库if __name__ == "__main__": 完整重复了一遍。Python 会把后半段视为对前面定义的二次覆盖,运行时不会报错但会造成 加载时间 ×2、阅读困惑、潜在版本不一致。直接删除重复部分;最好拆分模块(见 2.1)。
load_model_for_api() 默认 model_path='best_model.pth' ,而训练脚本把模型保存到 model_output/best_model.pthAPI 启动后找不到权重 → 500。统一路径或在 API 启动时把 model_path 设为同一值。
DataParallel + checkpoint:保存时用 model.module.state_dict(),但加载 checkpoint 时先把权重直接 load_state_dict未包 DataParallel 的模型,如果当时是多 GPU 训练而现在单 GPU 推理,会 键不匹配多 GPU 训练、单 GPU 推理会崩溃。保存时加 strict=False 或在加载端判断后包 DataParallel 再加载。
StreamingFaultDataset.__iter__row = line.strip().split(',') 没有考虑逗号转义/字段缺失;遇到异常只 continue,容易 静默丢数据准确率掉、难以追踪。csv.readerpandas.read_csv(chunksize)解析。
find_jump_point_local 对空/单点曲线已做保护,但在 _parse_and_align 若某条曲线字段为空串仍会进 float(val) 报错。解析异常。在 parse 前先判空;或用 `if not curve_str.strip('
API 端 num_classes 写死为 9;训练端 num_classes = max(class_counts)+1类别数变化会导致维度不符线上推理直接崩。将类别数写入 checkpoint / config json,加载时读取。

2 ️⃣ 结构与可维护性

  1. 拆分模块
    • 建议把数据集(datasets.py)、模型(models.py)、训练脚本(train.py)、API(api.py) 分开。
    • 便于单元测试,也能避免当前“整段复制”‍这种风险。
  2. 配置集中化
    • seq_length、batch_size、hidden_size、文件路径 等散落在函数签名与全局变量中,推荐用 yaml/inipydantic.BaseSettings 管理。
  3. 日志与进度条冲突
    • tqdmlogging 同时打印时会乱行;可以用 tqdm.write 或设置 logging.basicConfig(force=True) 并让 tqdm 显示 leave=False

3 ️⃣ 效率与资源

场景现状建议
训练数据加载StreamingFaultDataset 仍然逐行读取 CSV,然后在 Python 层解析字符串 → CPU bound。改为先用 pandas 分块读,或把原始 CSV 转成 Parquet / HDF5,再用二进制格式加速 I/O。
大文件扫描统计首次扫描用 csv.reader +sample_interval=100,仍遍历全部行。直接 nrows 随机采样或在 ETL 阶段离线生成 scalers,减少启动等待。
GPU 利用CNN+LSTM → batch_size 32 × 序列 150 ×4 通道 OK,但 DataLoader num_workers=0 限制了吞吐。在 Linux/WSL2 环境下把 num_workers 调到 CPU 核数(Windows 需 spawn_start_method 额外设置)。

4 ️⃣ 风格 & 可读性

  • 多处 import ... 重复;并且一次性 import 40+ 第三方库,API 镜像会很大
    建议把仅训练端用到的包(如 seaborn, tqdm, multiprocessing)拆进 requirements-train.txt,推理镜像只装 torch + fastapi + uvicorn
  • 函数 & 变量中英文混用(jump_point_alignmentA相),可以保留注释中文,代码标识符全部英文,有助协作。
  • _augment_batch 里四种增强方法直接写 if-else,可抽象为策略列表,方便启停/扩展。

5 ️⃣ 部署/API 细节

  1. FastAPI 模板目录 templates 未包含在仓库就会 404。可选:
    • 移除页面,REST-only;或
    • 用前端框架分离。
  2. 线程安全global_model/global_scalers 单进程 OK;若后续用 Gunicorn 多 worker,需要把加载逻辑改为 on_startup.
  3. CORS / 验证:暂无跨域与鉴权设置,线上直接暴露预测接口需注意。

✔️ 下一步怎么做?

  • 第一优先:去掉重复代码,修好 checkpoint 路径 & DataParallel 加载兼容。
  • 第二优先:把数据解析 / API 推理逻辑抽离到独立文件,写 1~2 个单元测试确保主要流程走通。
  • 第三优先:根据你的数据规模与部署环境,决定是否要改用 Parquet、优化 DataLoader、或引入更轻量的模型。

有任何点想深入(例如如何改写 DataParallel->DDP,怎样把 scalers 写进 checkpoint,或是 API 镜像精简)告诉我,咱们针对性动手!

分享这个问答