TCAS Model Analysis and Improvement
Created on: January 20, 2025
Study the following code: explain its logic, and point out the current model's shortcomings and the areas that could be improved further.
import math
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
from typing import List, Tuple, Dict
from torch.utils.tensorboard import SummaryWriter
import datetime
import os
NMI2FEET = 6076.11549
KT2FPS = 1.6878098571
G = 9.80665
FPS2KT = 1 / KT2FPS
FPM2FPS = 1 / 60
CONFLICT_THRESHOLD_H = 2000.0 # horizontal separation below 2000 ft is treated as a conflict trend
CONFLICT_THRESHOLD_V = 600.0 # vertical separation below 600 ft is treated as a conflict trend
DEFAULT_OWN_AIR_SPEED = 200 * KT2FPS
ACTIONS = {
'COC': 0, # Clear of Conflict
'DNC2000': 1,
'DND2000': 2,
'DNC1000': 3,
'DND1000': 4,
'DNC500': 5,
'DND500': 6,
'DNC': 7,
'DND': 8,
'MCL': 9,
'MDES': 10,
'CL1500': 11,
'DES1500': 12,
'SCL1500': 13,
'SDES1500': 14,
'SCL2500': 15,
'SDES2500': 16
}
action_number_map = [
'COC', 'DNC2000', 'DND2000', 'DNC1000', 'DND1000', 'DNC500', 'DND500', 'DNC', 'DND',
'MCL', 'MDES', 'CL1500', 'DES1500', 'SCL1500', 'SDES1500', 'SCL2500', 'SDES2500'
]
vertical_rate_limit = {
'COC': [-100.0, 100.0],
'DNC2000': [-100.0, 2000.0 / 60],
'DND2000': [-2000.0 / 60, 100.0],
'DNC1000': [-100.0, 1000.0 / 60],
'DND1000': [-1000.0 / 60, 100.0],
'DNC500': [-100.0, 500.0 / 60],
'DND500': [-500.0 / 60, 100.0],
'DNC': [-100.0, 0.0],
'DND': [0.0, 100.0],
'MCL': [0.0, 100.0],
'MDES': [-100.0, 0.0],
'CL1500': [1500.0 / 60, 100.0],
'DES1500': [-100.0, -1500.0 / 60],
'SCL1500': [1500.0 / 60, 100.0],
'SDES1500': [-100.0, -1500.0 / 60],
'SCL2500': [2500.0 / 60, 100.0],
'SDES2500': [-100.0, -2500.0 / 60]
}
ACCELS = [0] + [1 / 4] * 12 + [1 / 3] * 4  # per-action vertical acceleration as a fraction of g: 0 for COC, 1/4 g for standard RAs, 1/3 g for the strengthened RAs (scaled by G in step())
action_info = {
'COC': {'direction': 0, 'level': 0},
# Preventive RAs
'DNC2000': {'direction': 0, 'level': 1},
'DND2000': {'direction': 0, 'level': 1},
'DNC1000': {'direction': 0, 'level': 1},
'DND1000': {'direction': 0, 'level': 1},
'DNC500': {'direction': 0, 'level': 1},
'DND500': {'direction': 0, 'level': 1},
'DNC': {'direction': 0, 'level': 1},
'DND': {'direction': 0, 'level': 1},
'MCL': {'direction': +1, 'level': 1},
'MDES': {'direction': -1, 'level': 1},
# Corrective RAs
'CL1500': {'direction': +1, 'level': 2},
'DES1500': {'direction': -1, 'level': 2},
# Strengthened corrective RAs
'SCL1500': {'direction': +1, 'level': 3},
'SDES1500': {'direction': -1, 'level': 3},
'SCL2500': {'direction': +1, 'level': 3},
'SDES2500': {'direction': -1, 'level': 3},
}
allowed_tcas_transitions = {
# No RA → any transition is allowed
'COC': [
'COC', 'DNC2000', 'DND2000', 'DNC1000', 'DND1000', 'DNC500', 'DND500', 'DNC', 'DND', 'MCL', 'MDES',
'CL1500', 'DES1500', 'SCL1500', 'SDES1500', 'SCL2500', 'SDES2500'
],
# Preventive RA → may hold or upgrade to a corrective RA in the same sense; a direct reversal is discouraged
'DNC': ['DNC', 'DNC1000', 'DNC2000', 'CL1500', 'SCL1500', 'SCL2500', 'COC'],
'DND': ['DND', 'DND1000', 'DND2000', 'DES1500', 'SDES1500', 'SDES2500', 'COC'],
'DNC2000': ['DNC2000', 'DNC1000', 'DNC', 'CL1500', 'SCL1500', 'SCL2500', 'COC'],
'DND2000': ['DND2000', 'DND1000', 'DND', 'DES1500', 'SDES1500', 'SDES2500', 'COC'],
'DNC1000': ['DNC1000', 'DNC2000', 'DNC', 'CL1500', 'SCL1500', 'SCL2500', 'COC'],
'DND1000': ['DND1000', 'DND2000', 'DND', 'DES1500', 'SDES1500', 'SDES2500', 'COC'],
'DNC500': ['DNC500', 'CL1500', 'SCL1500', 'SCL2500', 'COC'],
'DND500': ['DND500', 'DES1500', 'SDES1500', 'SDES2500', 'COC'],
'MCL': ['MCL', 'CL1500', 'SCL1500', 'SCL2500', 'COC'],
'MDES': ['MDES', 'DES1500', 'SDES1500', 'SDES2500', 'COC'],
# Corrective RA (level 2) → may hold, strengthen to level 3, or fall back to preventive/COC
'CL1500': ['CL1500', 'SCL1500', 'SCL2500', 'COC', 'MCL', 'DNC', 'DNC2000', 'DNC1000'],
'DES1500': ['DES1500', 'SDES1500', 'SDES2500', 'COC', 'MDES', 'DND', 'DND2000', 'DND1000'],
# Strengthened corrective RA (level 3) → may hold or weaken back to level 2 or lower
'SCL1500': ['SCL1500', 'SCL2500', 'CL1500', 'MCL', 'COC'],
'SDES1500': ['SDES1500', 'SDES2500', 'DES1500', 'MDES', 'COC'],
'SCL2500': ['SCL2500', 'SCL1500', 'CL1500', 'MCL', 'COC'],
'SDES2500': ['SDES2500', 'SDES1500', 'DES1500', 'MDES', 'COC'],
}
class Config:
"""配置类"""
textdef __init__(self): # 时间步长 self.time_step = 1.0 # 本机性能限制 self.max_climb_rate = 6000.0 / 60.0 self.max_descent_rate = -6000.0 / 60.0 self.max_accel_up = 0.4 * G self.max_accel_down = -0.3 * G # 冲突/软惩罚阈值 self.conflict_threshold_h = CONFLICT_THRESHOLD_H self.conflict_threshold_v = CONFLICT_THRESHOLD_V # 惩罚/奖励系数 self.conflict_penalty = 1.0 self.ttc_penalty_factor = 0.5 self.distance_weight = 0.0001 # =========== 新增: TCAS2相关惩罚权重 =========== # 非法转移大额惩罚 self.illegal_transition_penalty = 5.0 # 反向(RA Reversal)惩罚 self.sense_reversal_penalty = 0.01 # 加强/减弱惩罚 self.strengthen_penalty = 0.005 self.weaken_penalty = 0.001 # DQN相关 self.target_update_method = 'soft' # or 'hard' self.tau = 0.001 # NMAC阈值 self.nmac_threshold_v = 100 self.nmac_threshold_h = 500 # 飞机数量 self.min_intruders = 90 self.max_intruders = 100 # 网络参数 self.hidden_dim = 128 self.num_heads = 4 # 训练参数 self.batch_size = 32 self.gamma = 0.99 self.learning_rate = 1e-4 self.memory_size = 10000 self.target_update_freq = 10 self.epsilon_start = 1.0 self.epsilon_end = 0.05 self.epsilon_decay = 0.995 # 归一化参数 self.normalization = { 'x_range': 5 * NMI2FEET, 'y_range': 3 * NMI2FEET, 'rel_h': 3000.0, 'v_speed': 100.0, 'horizontal_speed': 600 * KT2FPS, 'angle': 2 * np.pi, 'v_speed_o': 100.0 }
import math
def compute_tau_horizontal(intruder: dict,
own_hspeed: float) -> float:
"""
计算本机(own)与入侵机(intruder)在水平面上的Tau(到最近点的时间)。
如果不可能产生水平CPA(彼此在远离或平行),则返回 math.inf。
text参数: intruder: { 'x_range': float, # 当前相对X(入侵机 - 本机), ft 'y_range': float, # 当前相对Y(入侵机 - 本机), ft 'horizontal_speed': float, # 入侵机水平速度(标量), ft/s 'angle': float # 入侵机速度方向(相对水平面, 弧度或者自行约定) } own_hspeed: 本机水平速度(标量, ft/s), 此示例中假设指向 +X 方向 返回: tau_h: float 水平上的 Tau, 单位秒. 若水平距离不会减少, 返回 math.inf """ # 1) 计算相对位置 rel_x = intruder['x_range'] rel_y = intruder['y_range'] dist_h = math.sqrt(rel_x ** 2 + rel_y ** 2) if dist_h <= 1e-5: # 如果当前就非常接近(甚至重叠), 也可以视作 math.inf 或 0 return math.inf # 2) 计算相对速度 # 入侵机 vx, vy vx_int = intruder['horizontal_speed'] * math.cos(intruder['angle']) vy_int = intruder['horizontal_speed'] * math.sin(intruder['angle']) # 本机 vx_own = own_hspeed, vy_own = 0 (假设面向 +X) rel_vx = vx_int - own_hspeed rel_vy = vy_int # 3) 计算 “closing_speed” # closing_speed = - ( (rel_x*rel_vx + rel_y*rel_vy) / dist_h ) # 若 > 0 => 正在接近; <=0 => 远离或平行 closing_speed = - ((rel_x * rel_vx + rel_y * rel_vy) / dist_h) if closing_speed <= 1e-5: return math.inf # 4) tau_h = dist_h / closing_speed tau_h = dist_h / closing_speed return tau_h
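# --- Illustrative usage check (added for clarity; not part of the original script) ---
# Head-on geometry: an intruder 2 NM ahead on the +X axis, flying straight at the own
# ship at 300 kt, while the own ship flies along +X at 200 kt. The closing speed is then
# (300 + 200) kt ≈ 843.9 ft/s, so tau ≈ (2 * 6076.1 ft) / 843.9 ft/s ≈ 14.4 s:
#
#   _head_on = {'x_range': 2 * NMI2FEET, 'y_range': 0.0,
#               'horizontal_speed': 300 * KT2FPS, 'angle': math.pi}
#   compute_tau_horizontal(_head_on, own_hspeed=DEFAULT_OWN_AIR_SPEED)  # ≈ 14.4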
class MultiAircraftEnv:
"""多飞机环境 + TCAS2动作逻辑 + 奖励改进"""
textdef __init__(self, config: Config): self.config = config self.escape_threshold = 6 * NMI2FEET self.next_intruder_id = 1 self.initial_altitude = 15000.0 # 设置初始高度为15000ft # 添加action变动计数器 self.action_change_count = 0 # 总的变动次数 def initialize_intruder(self) -> Dict: x_range = np.random.uniform(3, 5) * np.random.choice([-1, 1]) * NMI2FEET y_range = np.random.uniform(0, 3) * np.random.choice([-1, 1]) * NMI2FEET # 生成角度(随意) angle = 0 if x_range > 0 and y_range > 0: angle = np.random.uniform(1, 2) * np.pi elif x_range > 0 and y_range < 0: angle = np.random.uniform(0, 1) * np.pi elif x_range < 0 and y_range > 0: angle = np.random.uniform(3 / 2, 2) * np.pi else: angle = np.random.uniform(0, 1 / 2) * np.pi rel_h = np.random.uniform(-3000, 3000) v_speed = np.random.uniform(-100, 100) horizontal_speed = np.random.uniform(200, 600) * KT2FPS intruder_state = { 'x_range': x_range, 'y_range': y_range, 'rel_h': rel_h, 'v_speed': v_speed, 'horizontal_speed': horizontal_speed, 'angle': angle, 'id': self.next_intruder_id, 'escaped': False, 'horizontal_range': math.sqrt(x_range ** 2 + y_range ** 2) } self.next_intruder_id += 1 return intruder_state def reset(self) -> Tuple[torch.Tensor, torch.Tensor]: self.action_change_count = 0 # 重置计数器 num_intruders = random.randint(self.config.min_intruders, self.config.max_intruders) self.own_state = { 'altitude': self.initial_altitude, # 添加高度状态但不用于训练 'v_speed_o': 0.0, # 或随机 'prev_action': 0 } self.intruder_states = [] for _ in range(num_intruders): intruder = self.initialize_intruder() self.intruder_states.append(intruder) return self.get_state_tensors() def get_state_tensors(self) -> Tuple[torch.Tensor, torch.Tensor]: own_tensor = torch.tensor([ self.own_state['v_speed_o'] / self.config.normalization['v_speed_o'], self.own_state['prev_action'] ], dtype=torch.float32) active_intruders = [i for i in self.intruder_states if not i['escaped']] if not active_intruders: intruder_tensor = torch.zeros((0, 6)) else: intruder_tensor = torch.tensor([ [ it['rel_h'] / self.config.normalization['rel_h'], it['v_speed'] / self.config.normalization['v_speed'], it['x_range'] / self.config.normalization['x_range'], it['y_range'] / self.config.normalization['y_range'], it['horizontal_speed'] / self.config.normalization['horizontal_speed'], it['angle'] / self.config.normalization['angle'] ] for it in active_intruders ], dtype=torch.float32) return own_tensor, intruder_tensor def clamp_own_vertical_speed(self, v_speed: float) -> float: if v_speed > self.config.max_climb_rate: return self.config.max_climb_rate if v_speed < self.config.max_descent_rate: return self.config.max_descent_rate return v_speed def compute_ttc(self, intruder) -> float: vx = intruder['horizontal_speed'] * math.cos(intruder['angle']) vy = intruder['horizontal_speed'] * math.sin(intruder['angle']) dist = intruder['horizontal_range'] # 相对速度(本机假设水平速度=200kts) rel_vx = vx - DEFAULT_OWN_AIR_SPEED rel_vy = vy if dist <= 0: return 9999.0 closing_speed = -((intruder['x_range'] * rel_vx + intruder['y_range'] * rel_vy) / dist) if closing_speed <= 0: return 9999.0 return dist / closing_speed def get_vertical_rate(self, action: int, is_own: bool, acceleration: float, current_rate: float) -> float: ra_name = action_number_map[action] rate_limit = vertical_rate_limit[ra_name].copy() if is_own: if ra_name == 'MCL': rate_limit[0] = current_rate if ra_name == 'MDES': rate_limit[1] = current_rate next_rate = current_rate + acceleration * self.config.time_step # next_rate = max(rate_limit[0], min(next_rate, rate_limit[1])) next_rate = 
self.clamp_own_vertical_speed(next_rate) else: next_rate = current_rate + acceleration * self.config.time_step next_rate = np.clip(next_rate, -100, 100) # if is_own: # print("current speed: ", current_rate, "-> next speed: ", next_rate, "acceleration: ", acceleration, "rate_limit: ", rate_limit, "action: ", ra_name) return next_rate def step(self, action: int) -> Tuple[Tuple[torch.Tensor, torch.Tensor], float, bool]: prev_action = self.own_state['prev_action'] prev_str = action_number_map[prev_action] current_str = action_number_map[action] # 统计action变动 if action != prev_action: self.action_change_count += 1 # ============== # 1) 检查TCAS转移 # ============== total_reward = 0.0 # 若不在允许表中,则惩罚 if current_str not in allowed_tcas_transitions[prev_str]: total_reward -= self.config.illegal_transition_penalty # 方向/等级 判断 dir_prev = action_info[prev_str]['direction'] dir_now = action_info[current_str]['direction'] lvl_prev = action_info[prev_str]['level'] lvl_now = action_info[current_str]['level'] # sense reversal (爬升变下降或下降变爬升) if dir_prev != 0 and dir_now != 0 and (dir_prev * dir_now < 0): total_reward -= self.config.sense_reversal_penalty # strengthen / weaken if dir_prev == dir_now and dir_now != 0: # 同方向 if lvl_now > lvl_prev: # 加强 total_reward -= self.config.strengthen_penalty elif lvl_now < lvl_prev: # 减弱 total_reward -= self.config.weaken_penalty # 更新本机垂直速度 accel = ACCELS[action] * G prev_v_speed_own = self.own_state['v_speed_o'] self.own_state['v_speed_o'] = self.get_vertical_rate(action, True, accel, prev_v_speed_own) # 使用平均垂直速度更新高度 avg_v_speed = (prev_v_speed_own + self.own_state['v_speed_o']) / 2.0 self.own_state['altitude'] += avg_v_speed * self.config.time_step self.own_state['prev_action'] = action nmac_occurred = False for intruder in self.intruder_states: if intruder['escaped']: continue v_accel = np.random.uniform(self.config.max_accel_down, self.config.max_accel_up) prev_v_speed_int = intruder['v_speed'] intruder['v_speed'] = self.get_vertical_rate(0, False, v_accel, prev_v_speed_int) old_range = intruder['horizontal_range'] vx = intruder['horizontal_speed'] * math.cos(intruder['angle']) vy = intruder['horizontal_speed'] * math.sin(intruder['angle']) # xinew = xiold + vx xonew = xoold + 200 # x_range = xiold -xoold + vx - 200 = x_range_old + (vx - 200) intruder['x_range'] += (vx * self.config.time_step - DEFAULT_OWN_AIR_SPEED * self.config.time_step) intruder['y_range'] += vy * self.config.time_step own_avg_v = (prev_v_speed_own + self.own_state['v_speed_o']) / 2.0 int_avg_v = (prev_v_speed_int + intruder['v_speed']) / 2.0 intruder['rel_h'] += (int_avg_v - own_avg_v) * self.config.time_step intruder['horizontal_range'] = math.sqrt(intruder['x_range'] ** 2 + intruder['y_range'] ** 2) # 计算水平tau: tau_h = compute_tau_horizontal(intruder, own_hspeed=DEFAULT_OWN_AIR_SPEED) intruder['tau_h'] = tau_h # 存起来, 以备查看 # NMAC检测 if abs(intruder['rel_h']) < self.config.nmac_threshold_v and intruder[ 'horizontal_range'] < self.config.nmac_threshold_h: nmac_occurred = True total_reward -= 10.0 # 软惩罚:距离/TTC if abs(intruder['rel_h']) < self.config.conflict_threshold_v and intruder[ 'horizontal_range'] < self.config.conflict_threshold_h: dist_penalty = 0.0 if intruder['horizontal_range'] > 0: dist_penalty += self.config.distance_weight / intruder['horizontal_range'] if abs(intruder['rel_h']) > 0: dist_penalty += self.config.distance_weight / abs(intruder['rel_h']) total_reward -= dist_penalty ttc = self.compute_ttc(intruder) if ttc < 9999.0: total_reward -= self.config.ttc_penalty_factor / (ttc + 1e-5) # 
逃脱判断 if intruder['horizontal_range'] > self.escape_threshold and intruder[ 'horizontal_range'] > old_range or math.isinf(intruder['tau_h']): intruder['escaped'] = True total_reward += 0.1 # (修改) 对 total_reward 做裁剪, 避免出现极端大的负值 # total_reward = max(-10.0, min(total_reward, 10.0)) # done判定 active_intruders = [i for i in self.intruder_states if not i['escaped']] done = nmac_occurred or (len(active_intruders) == 0) return self.get_state_tensors(), total_reward, done def get_statistics(self) -> Dict: active_count = len([i for i in self.intruder_states if not i['escaped']]) escaped_count = len(self.intruder_states) - active_count return { 'active_intruders': active_count, 'escaped_intruders': escaped_count, 'total_intruders': len(self.intruder_states), 'action_changes': self.action_change_count # 添加action变动次数 }
def record_episode(env, episode_data, filename):
"""记录一个episode的状态数据到文件"""
with open(filename, 'w') as f:
# 写入初始配置
config = {
"initial_altitude": env.initial_altitude, # 初始高度
"time_step": env.config.time_step,
"own_speed": DEFAULT_OWN_AIR_SPEED * FPS2KT, # 节
"total_steps": len(episode_data)
}
f.write(f"#CONFIG\n{str(config)}\n")
f.write("#DATA\n")
text# 写入每个时间步的状态 for state in episode_data: f.write(f"{str(state)}\n")
class CASNetwork(nn.Module):
"""使用注意力机制处理可变数量入侵机的防撞系统网络"""
textdef __init__(self, config: Config): super().__init__() self.config = config # 本机状态编码器 self.own_encoder = nn.Sequential( nn.Linear(2, config.hidden_dim), nn.ReLU() ) # 入侵机状态编码器 self.intruder_encoder = nn.Sequential( nn.Linear(6, config.hidden_dim), nn.ReLU() ) # 注意力层 self.attention = nn.MultiheadAttention( config.hidden_dim, config.num_heads, batch_first=True ) # 输出层 self.output_layer = nn.Sequential( nn.Linear(config.hidden_dim * 2, config.hidden_dim), nn.ReLU(), nn.Linear(config.hidden_dim, len(ACTIONS)) ) def forward(self, own_state, intruder_states): own_encoded = self.own_encoder(own_state) if intruder_states.size(1) == 0: intruder_feature = torch.zeros_like(own_encoded) else: bsz, num_intr, _ = intruder_states.size() intr_enc = self.intruder_encoder( intruder_states.view(-1, 6) ).view(bsz, num_intr, -1) attn_output, _ = self.attention( own_encoded.unsqueeze(1), intr_enc, intr_enc ) intruder_feature = attn_output.squeeze(1) combined = torch.cat([own_encoded, intruder_feature], dim=1) return self.output_layer(combined)
class DQNAgent:
"""Double DQN智能体"""
textdef __init__(self, config: Config): self.config = config self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.policy_net = CASNetwork(config).to(self.device) self.target_net = CASNetwork(config).to(self.device) self.target_net.load_state_dict(self.policy_net.state_dict()) self.target_net.eval() self.optimizer = optim.Adam(self.policy_net.parameters(), lr=config.learning_rate) self.memory = deque(maxlen=config.memory_size) self.epsilon = config.epsilon_start def pad_intruder_states(self, intruder_states: torch.Tensor, max_intruders: int) -> torch.Tensor: batch_size = 1 if len(intruder_states.shape) == 2 else intruder_states.shape[0] current_intruders = intruder_states.shape[-2] if current_intruders < max_intruders: padding = torch.zeros( (batch_size, max_intruders - current_intruders, 6), device=intruder_states.device ) return torch.cat([intruder_states, padding], dim=-2) return intruder_states[:, :max_intruders, :] def select_action(self, own_state: torch.Tensor, intruder_states: torch.Tensor) -> int: if random.random() < self.epsilon: return random.randrange(len(ACTIONS)) with torch.no_grad(): own_state = own_state.unsqueeze(0).to(self.device) intruder_states = intruder_states.unsqueeze(0).to(self.device) intruder_states = self.pad_intruder_states(intruder_states, self.config.max_intruders) q_values = self.policy_net(own_state, intruder_states) return q_values.argmax(dim=1).item() def store_transition(self, s_own, s_intr, action, reward, s2_own, s2_intr, done): self.memory.append((s_own, s_intr, action, reward, s2_own, s2_intr, done)) def train_step(self): if len(self.memory) < self.config.batch_size: return None batch = random.sample(self.memory, self.config.batch_size) s_own, s_intr, acts, rws, s2_own, s2_intr, dns = zip(*batch) s_own = torch.stack(s_own).to(self.device) acts = torch.tensor(acts, dtype=torch.long).to(self.device) rws = torch.tensor(rws, dtype=torch.float).to(self.device) s2_own = torch.stack(s2_own).to(self.device) dns = torch.tensor(dns, dtype=torch.float).to(self.device) max_intruders = max(x.size(0) for x in s_intr) max_intruders = min(max_intruders, self.config.max_intruders) intr_batch = [] for s_i in s_intr: padded = self.pad_intruder_states(s_i.unsqueeze(0), max_intruders).squeeze(0) intr_batch.append(padded) intr_batch = torch.stack(intr_batch).to(self.device) intr2_batch = [] for s2_i in s2_intr: padded2 = self.pad_intruder_states(s2_i.unsqueeze(0), max_intruders).squeeze(0) intr2_batch.append(padded2) intr2_batch = torch.stack(intr2_batch).to(self.device) q_curr = self.policy_net(s_own, intr_batch) q_curr_a = q_curr.gather(1, acts.unsqueeze(1)) with torch.no_grad(): next_actions = self.policy_net(s2_own, intr2_batch).argmax(1) q_next = self.target_net(s2_own, intr2_batch) q_next_a = q_next.gather(1, next_actions.unsqueeze(1)).squeeze(1) target_q = rws + (1 - dns) * self.config.gamma * q_next_a loss_fn = nn.SmoothL1Loss() loss = loss_fn(q_curr_a.squeeze(), target_q) self.optimizer.zero_grad() loss.backward() nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0) self.optimizer.step() self.epsilon = max(self.config.epsilon_end, self.epsilon * self.config.epsilon_decay) return loss.item() def update_target_network(self): self.target_net.load_state_dict(self.policy_net.state_dict()) def soft_update_target_network(self, tau=0.001): for tp, pp in zip(self.target_net.parameters(), self.policy_net.parameters()): tp.data.copy_(tau * pp.data + (1.0 - tau) * tp.data)
def train_cas(config: Config, num_episodes: int = 5000):
env = MultiAircraftEnv(config)
agent = DQNAgent(config)
textcurrent_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S') log_dir = os.path.join('runs', f'CAS_training_{current_time}') writer = SummaryWriter(log_dir) episode_rewards = [] nmac_counts = [] for episode in range(num_episodes): own_state, intr_states = env.reset() ep_reward = 0.0 steps = 0 nmac_count = 0 episode_loss = 0.0 num_training_steps = 0 initial_intruders = len(env.intruder_states) # 记录这个episode的所有状态 episode_data = [] while True: # 记录当前状态 current_state = { "own_state": { "altitude": env.own_state['altitude'], # 累积高度变化 "v_speed": env.own_state['v_speed_o'], "action": action_number_map[env.own_state['prev_action']] }, "intruders": [{ "id": intr['id'], "x_range": intr['x_range'], "y_range": intr['y_range'], "rel_h": intr['rel_h'], "v_speed": intr['v_speed'], "horizontal_speed": intr['horizontal_speed'], "angle": intr['angle'], "escaped": intr['escaped'] } for intr in env.intruder_states] } # 添加到episode_data列表 episode_data.append(current_state) action = agent.select_action(own_state, intr_states) (next_own, next_intr), reward, done = env.step(action) agent.store_transition(own_state, intr_states, action, reward, next_own, next_intr, done) if len(agent.memory) >= config.batch_size: loss = agent.train_step() if loss is not None: episode_loss += loss num_training_steps += 1 ep_reward += reward steps += 1 if reward <= -100.0: nmac_count += 1 own_state, intr_states = next_own, next_intr # 软更新 agent.soft_update_target_network(tau=config.tau) if done: current_time = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') filename = f'episode_data.txt' record_episode(env, episode_data, filename) break episode_rewards.append(ep_reward) nmac_counts.append(nmac_count) # 日志记录 writer.add_scalar('1.Rewards/Episode_Reward', ep_reward, episode) if len(episode_rewards) > 0: avg_r = sum(episode_rewards[-100:]) / min(len(episode_rewards), 100) writer.add_scalar('1.Rewards/Average_Reward', avg_r, episode) writer.add_scalar('2.Safety/NMAC_Count', nmac_count, episode) if len(nmac_counts) > 0: avg_n = sum(nmac_counts[-100:]) / min(len(nmac_counts), 100) writer.add_scalar('2.Safety/Average_NMAC', avg_n, episode) writer.add_scalar('3.Training/Steps', steps, episode) writer.add_scalar('3.Training/Epsilon', agent.epsilon, episode) if num_training_steps > 0: writer.add_scalar('3.Training/Average_Loss', episode_loss / num_training_steps, episode) stats = env.get_statistics() writer.add_scalar('6.Intruders/Initial_Count', initial_intruders, episode) writer.add_scalar('6.Intruders/Active_Count', stats['active_intruders'], episode) writer.add_scalar('6.Intruders/Escaped_Count', stats['escaped_intruders'], episode) writer.add_scalar('7.Own_State/Vertical_Speed', env.own_state['v_speed_o'], episode) # 添加action变动记录 writer.add_scalar('7.Own_State/Action_Changes', stats['action_changes'], episode) # 打印 print(f"\nEpisode {episode + 1}/{num_episodes}") print(f"Steps: {steps}, Reward: {ep_reward:.2f}, NMAC: {nmac_count}, Eps: {agent.epsilon:.3f}") print(f"Initial Intruders: {initial_intruders}") if len(episode_rewards) > 0: avg_r = sum(episode_rewards[-100:]) / min(len(episode_rewards), 100) print(f"AvgReward(100): {avg_r:.2f}") if len(nmac_counts) > 0: avg_n = sum(nmac_counts[-100:]) / min(len(nmac_counts), 100) print(f"AvgNMAC(100): {avg_n:.3f}") if num_training_steps > 0: print(f"AvgLoss: {episode_loss / num_training_steps:.6f}") print(f"Active Intruders: {stats['active_intruders']}, Escaped Intruders: {stats['escaped_intruders']}") print(f"Action Changes: {stats['action_changes']}") # 打印action变动次数 print("-" * 50) 
writer.close() return agent, episode_rewards, nmac_counts
def evaluate_cas(agent: DQNAgent, config: Config, num_episodes: int = 100):
env = MultiAircraftEnv(config)
textcurrent_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S') log_dir = os.path.join('runs', f'CAS_evaluation_{current_time}') writer = SummaryWriter(log_dir) total_rewards = [] nmac_counts = [] step_counts = [] old_eps = agent.epsilon agent.epsilon = 0.0 for ep in range(num_episodes): own_state, intr_states = env.reset() ep_reward = 0.0 steps = 0 nmac_count = 0 action_counts = torch.zeros(len(ACTIONS)) while True: action = agent.select_action(own_state, intr_states) action_counts[action] += 1 (next_own, next_intr), reward, done = env.step(action) ep_reward += reward steps += 1 if reward <= -100.0: nmac_count += 1 own_state, intr_states = next_own, next_intr if done: break writer.add_scalar('Evaluation/Episode_Reward', ep_reward, ep) writer.add_scalar('Evaluation/NMAC_Count', nmac_count, ep) writer.add_scalar('Evaluation/Steps', steps, ep) action_probs = action_counts / steps for name, idx in ACTIONS.items(): writer.add_scalar(f'Evaluation/Action_{name}', action_probs[idx].item(), ep) if hasattr(env, 'get_statistics'): st = env.get_statistics() writer.add_scalar('Evaluation/Active_Intruders', st['active_intruders'], ep) writer.add_scalar('Evaluation/Escaped_Intruders', st['escaped_intruders'], ep) total_rewards.append(ep_reward) nmac_counts.append(nmac_count) step_counts.append(steps) print(f"\nEvaluation Episode {ep + 1}") print(f"Steps: {steps}, Reward: {ep_reward:.2f}, NMAC: {nmac_count}") avg_r = sum(total_rewards) / len(total_rewards) avg_n = sum(nmac_counts) / len(nmac_counts) avg_stp = sum(step_counts) / len(step_counts) writer.add_hparams( {'eval_episodes': num_episodes}, { 'hparam/avg_reward': avg_r, 'hparam/avg_nmac': avg_n, 'hparam/avg_steps': avg_stp } ) writer.close() agent.epsilon = old_eps print("\nEvaluation Results:") print(f"Avg Reward: {avg_r:.2f}") print(f"Avg NMAC: {avg_n:.3f}") print(f"Avg Steps: {avg_stp:.1f}") return avg_r, avg_n, avg_stp
if __name__ == "__main__":
config = Config()
textprint("Starting training...") agent, ep_rewards, nmacs = train_cas(config, num_episodes=5000) # 保存模型 save_dir = 'models' os.makedirs(save_dir, exist_ok=True) current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S') model_path = os.path.join(save_dir, f'cas_model_{current_time}.pth') torch.save({ 'policy_net_state_dict': agent.policy_net.state_dict(), 'target_net_state_dict': agent.target_net.state_dict(), 'optimizer_state_dict': agent.optimizer.state_dict(), 'config': config, }, model_path) # 评估 print("\nStarting evaluation...") evaluate_cas(agent, config, num_episodes=1000)
Starting from the overall framework, the following first gives an overview of the code's main logic, then discusses some shortcomings of the model and the training procedure, and finally proposes several directions for improvement.

Constant definitions

TCAS-related actions and transition restrictions

ACTIONS and action_number_map define 17 action types; in real TCAS logic each one can be read as a climb/descend advisory of a different level (preventive, corrective, or strengthened corrective). vertical_rate_limit gives the lower and upper vertical-rate bounds for each action. action_info annotates every action with a direction (+1 climb, -1 descend, 0 neutral) and a level (0 = no RA, 1 = preventive RA, 2 = corrective RA, 3 = strengthened corrective RA). allowed_tcas_transitions defines which transitions between consecutive TCAS actions are legal, e.g. a preventive RA may be upgraded to a corrective RA of the same sense, while a direct sense reversal is discouraged.
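To make concrete how these tables interact, here is a minimal sketch of how a transition between two RAs is scored, using the allowed_tcas_transitions and action_info dictionaries above; the default penalty weights mirror the values in Config, and the function name is illustrative:

def ra_transition_penalty(prev_name: str, curr_name: str,
                          illegal: float = 5.0, reversal: float = 0.01,
                          strengthen: float = 0.005, weaken: float = 0.001) -> float:
    """Score an RA change prev_name -> curr_name the way MultiAircraftEnv.step() does."""
    penalty = 0.0
    if curr_name not in allowed_tcas_transitions[prev_name]:
        penalty += illegal                      # transition not in the legality table
    d_prev = action_info[prev_name]['direction']
    d_now = action_info[curr_name]['direction']
    l_prev = action_info[prev_name]['level']
    l_now = action_info[curr_name]['level']
    if d_prev != 0 and d_now != 0 and d_prev * d_now < 0:
        penalty += reversal                     # sense reversal (climb <-> descend)
    if d_prev == d_now and d_now != 0:
        if l_now > l_prev:
            penalty += strengthen               # e.g. CL1500 -> SCL2500
        elif l_now < l_prev:
            penalty += weaken                   # e.g. SCL2500 -> CL1500
    return penalty

# Example: ra_transition_penalty('CL1500', 'SDES1500') == 5.01 (illegal transition + sense reversal)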
Environment class MultiAircraftEnv

Own-ship state: altitude altitude, own vertical speed v_speed_o, and the previous action prev_action. Intruder state: relative position (x_range, y_range), relative altitude (rel_h), horizontal speed horizontal_speed with its heading angle, vertical speed v_speed, plus an escaped flag, and so on.

Reset (reset): creates a random number of intruders (between min_intruders and max_intruders), each with a random initial position, speed, heading and relative altitude; the own ship starts at initial_altitude (15000 ft in this example) with an initial vertical speed of 0 or another random value.

State tensors (get_state_tensors): converts the own-ship state (v_speed_o and the index of prev_action) into a 1-D tensor, and stacks the active intruders into a tensor of shape [number_of_intruders, 6].

Step (step): the chosen action is checked against allowed_tcas_transitions for legality; distance and TTC (time to collision) also incur penalties; the method returns the new state (own_tensor, intruder_tensor), this step's reward reward, and the episode-termination flag done. (The horizontal motion used inside step is restated as a small helper right below.)
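The horizontal kinematics inside step() reduce to the following relative-frame Euler update, restated here as a standalone helper for clarity; the own ship is assumed to fly along +X at a fixed airspeed, as in compute_tau_horizontal, and the helper name is illustrative:

import math

def advance_relative_position(intruder: dict, own_hspeed: float, dt: float = 1.0) -> None:
    """One Euler step of the intruder's position relative to the own ship (feet, seconds)."""
    vx = intruder['horizontal_speed'] * math.cos(intruder['angle'])
    vy = intruder['horizontal_speed'] * math.sin(intruder['angle'])
    # Only relative motion is tracked: the own ship moves along +X at own_hspeed
    intruder['x_range'] += (vx - own_hspeed) * dt
    intruder['y_range'] += vy * dt
    intruder['horizontal_range'] = math.hypot(intruder['x_range'], intruder['y_range'])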
Network (CASNetwork) overall structure:

The own-ship state is encoded to hidden_dim with Linear -> ReLU, and each intruder state is likewise encoded to hidden_dim with Linear -> ReLU. The own-ship encoding serves as the attention query, and the sequence of intruder encodings serves as key and value; the attention output is squeezed back to shape [B, hidden_dim].

This lets the network aggregate, or weight, a variable number of intruders through attention into a single global intruder representation, which is concatenated with the own-ship feature to produce the action values.
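As a standalone shape check of this aggregation (a batch of 32 states with 10 intruders each, matching the hidden_dim = 128, 4-head configuration above; the tensors are random placeholders):

import torch
import torch.nn as nn

B, N, H = 32, 10, 128
own_encoded = torch.randn(B, H)        # encoded own-ship state
intr_enc = torch.randn(B, N, H)        # encoded intruder states
attn = nn.MultiheadAttention(H, num_heads=4, batch_first=True)
out, weights = attn(own_encoded.unsqueeze(1), intr_enc, intr_enc)
intruder_feature = out.squeeze(1)                              # [32, 128] attention-weighted intruder summary
combined = torch.cat([own_encoded, intruder_feature], dim=1)   # [32, 256], fed to the Q-value head
print(intruder_feature.shape, weights.shape)                   # torch.Size([32, 128]) torch.Size([32, 1, 10])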
Agent (DQNAgent) main structure: a policy_net plus a target_net (Double DQN); a replay buffer stores experience and is sampled uniformly at random for training; actions are chosen with an epsilon-greedy policy (the target computation is restated right after this overview).

Training (train_cas function): each episode calls env.reset() to create a fresh environment and intruders, then steps until done, logs the episode and exits the loop.

Evaluation (evaluate_cas function): runs the learned policy greedily (epsilon set to 0) for a number of episodes and logs the results.
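The target used in train_step() can be restated compactly as below (the function and variable names are illustrative; the policy network picks the next action and the target network evaluates it):

import torch

def double_dqn_target(reward, done, q_next_policy, q_next_target, gamma=0.99):
    next_actions = q_next_policy.argmax(dim=1, keepdim=True)    # greedy action from the policy net
    q_next = q_next_target.gather(1, next_actions).squeeze(1)   # its value under the target net
    return reward + (1.0 - done) * gamma * q_next

# Toy check with two transitions and three actions:
r = torch.tensor([0.1, -10.0]); d = torch.tensor([0.0, 1.0])
qp = torch.tensor([[1.0, 2.0, 0.5], [0.3, 0.1, 0.2]])
qt = torch.tensor([[0.8, 1.5, 0.4], [0.2, 0.1, 0.3]])
print(double_dqn_target(r, d, qp, qt))   # tensor([ 1.5850, -10.0000])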
Shortcomings of the current model

The environment is simple and rests on many assumptions.

The reward design and its scalarization are complex: many penalty terms must be folded into, and balanced within, a single scalar.

The action-transition rules are hard-coded (allowed_tcas_transitions).

Training stability: the code already uses a target network, soft updates and gradient clipping (clip_grad_norm_), but convergence may still be fragile.

Input handling of the multi-intruder attention network: padding/truncating to a fixed max_intruders partly solves the variable-size input problem, but with very many aircraft on screen the attention computation becomes expensive, and any aircraft beyond max_intruders is truncated, potentially dropping information the decision needs (one possible mitigation is sketched below).

The discrete action space may still fall short of real decision-making requirements.
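One low-cost mitigation for the padding issue is to pass a key_padding_mask to nn.MultiheadAttention so that zero-padded intruder slots do not influence the attention output. This is a sketch under the assumption that padded rows are all zeros (as produced by pad_intruder_states) and that every sample contains at least one real intruder:

import torch
import torch.nn as nn

attn = nn.MultiheadAttention(128, 4, batch_first=True)

def attend_ignoring_padding(own_encoded, intr_enc, intr_raw):
    # intr_raw: [B, N, 6] padded raw intruder features; a slot counts as padding if every entry is zero
    pad_mask = (intr_raw.abs().sum(dim=-1) == 0)   # [B, N], True where padded
    out, _ = attn(own_encoded.unsqueeze(1), intr_enc, intr_enc,
                  key_padding_mask=pad_mask)
    return out.squeeze(1)                          # [B, 128]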
Possible improvements

Increase the environment's complexity and realism.

Improve the reward function.

Apply more reinforcement-learning refinements on the DQN side (for example a dueling Q-head; a minimal sketch is given after the summary paragraph below).

Optimize the structure of the attention network.

Design a more flexible action space.

Carry out more comprehensive safety verification.

The example code implements a DQN with a multi-aircraft environment and a multi-head-attention network to model TCAS-style conflict detection and avoidance decisions. It includes a discretized action space, hard-coded action-transition rules, and basic reward and conflict-detection logic, and it can run end to end in a small-scale environment with preliminary learning results. However, because the environment is simple, the reward design is still coarse, and the algorithm itself is basic, bringing it closer to real-world scenarios would require further work on environment modeling, the reward function, the network architecture, and the reinforcement-learning algorithm.
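As one concrete example of the "more RL refinements" direction above, a dueling-style Q-head could replace the final output_layer of CASNetwork. This is an illustrative sketch, not part of the original code; in_dim = 256 matches the concatenated own/intruder features at hidden_dim = 128:

import torch
import torch.nn as nn

class DuelingHead(nn.Module):
    """Illustrative dueling Q-head: Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)."""
    def __init__(self, in_dim: int = 256, hidden_dim: int = 128, num_actions: int = 17):
        super().__init__()
        self.value = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU(),
                                   nn.Linear(hidden_dim, 1))
        self.advantage = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU(),
                                       nn.Linear(hidden_dim, num_actions))

    def forward(self, combined: torch.Tensor) -> torch.Tensor:
        v = self.value(combined)                      # [B, 1]
        a = self.advantage(combined)                  # [B, num_actions]
        return v + a - a.mean(dim=1, keepdim=True)    # [B, num_actions]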
Please improve the original code with respect to increasing the environment's complexity/realism and improving the reward function, and provide an explanation together with the complete improved code.
Below is an illustrative improved version with changes in two areas: environment complexity and the reward function. The changes are not necessarily complete or optimal, but they can serve as a starting point for further extension and experiments. The main changes are:

Environment complexity
Each intruder now carries a turn rate (turn_rate) and updates its own angle every step; the intruders' horizontal speed (horizontal_speed) also fluctuates randomly within a range, to mimic the speed changes of real aircraft.

Reward function
See the notes after the code (a per-step survival reward plus a penalty on large own-ship vertical speeds).

The complete example code follows (including the original classes and functions); comments were only added in MultiAircraftEnv and train_cas to mark the main changes, so that readers can compare quickly.
pythonimport math import numpy as np import torch import torch.nn as nn import torch.optim as optim import random from collections import deque from typing import List, Tuple, Dict from torch.utils.tensorboard import SummaryWriter import datetime import os # ====================== # 常量定义 # ====================== NMI2FEET = 6076.11549 KT2FPS = 1.6878098571 G = 9.80665 FPS2KT = 1 / KT2FPS FPM2FPS = 1 / 60 # 可以加一个更宽松的冲突检测阈值,用来施加软惩罚 CONFLICT_THRESHOLD_H = 2000.0 # 水平距离小于2000 ft,就认为有冲突趋势 CONFLICT_THRESHOLD_V = 600.0 # 垂直距离小于600 ft,就认为有冲突趋势 DEFAULT_OWN_AIR_SPEED = 200 * KT2FPS # --------------------- # 动作空间 (与原代码一致) # --------------------- ACTIONS = { 'COC': 0, # Clear of Conflict 'DNC2000': 1, 'DND2000': 2, 'DNC1000': 3, 'DND1000': 4, 'DNC500': 5, 'DND500': 6, 'DNC': 7, 'DND': 8, 'MCL': 9, 'MDES': 10, 'CL1500': 11, 'DES1500': 12, 'SCL1500': 13, 'SDES1500': 14, 'SCL2500': 15, 'SDES2500': 16 } action_number_map = [ 'COC', 'DNC2000', 'DND2000', 'DNC1000', 'DND1000', 'DNC500', 'DND500', 'DNC', 'DND', 'MCL', 'MDES', 'CL1500', 'DES1500', 'SCL1500', 'SDES1500', 'SCL2500', 'SDES2500' ] # --------------------- # 为每个动作定义 垂直速率限制 (与原代码一致) # --------------------- vertical_rate_limit = { 'COC': [-100.0, 100.0], 'DNC2000': [-100.0, 2000.0 / 60], 'DND2000': [-2000.0 / 60, 100.0], 'DNC1000': [-100.0, 1000.0 / 60], 'DND1000': [-1000.0 / 60, 100.0], 'DNC500': [-100.0, 500.0 / 60], 'DND500': [-500.0 / 60, 100.0], 'DNC': [-100.0, 0.0], 'DND': [0.0, 100.0], 'MCL': [0.0, 100.0], 'MDES': [-100.0, 0.0], 'CL1500': [1500.0 / 60, 100.0], 'DES1500': [-100.0, -1500.0 / 60], 'SCL1500': [1500.0 / 60, 100.0], 'SDES1500': [-100.0, -1500.0 / 60], 'SCL2500': [2500.0 / 60, 100.0], 'SDES2500': [-100.0, -2500.0 / 60] } # --------------------- # 加速度表(原逻辑) # --------------------- ACCELS = [0] + [1 / 4] * 12 + [1 / 3] * 4 # =================================================== # 1) 为TCAS动作设计"方向"和"等级" # =================================================== action_info = { 'COC': {'direction': 0, 'level': 0}, # 预防性RA 'DNC2000': {'direction': 0, 'level': 1}, 'DND2000': {'direction': 0, 'level': 1}, 'DNC1000': {'direction': 0, 'level': 1}, 'DND1000': {'direction': 0, 'level': 1}, 'DNC500': {'direction': 0, 'level': 1}, 'DND500': {'direction': 0, 'level': 1}, 'DNC': {'direction': 0, 'level': 1}, 'DND': {'direction': 0, 'level': 1}, 'MCL': {'direction': +1, 'level': 1}, 'MDES': {'direction': -1, 'level': 1}, # 纠正性RA 'CL1500': {'direction': +1, 'level': 2}, 'DES1500': {'direction': -1, 'level': 2}, # 更强纠正RA 'SCL1500': {'direction': +1, 'level': 3}, 'SDES1500': {'direction': -1, 'level': 3}, 'SCL2500': {'direction': +1, 'level': 3}, 'SDES2500': {'direction': -1, 'level': 3}, } # =================================================== # 2) 定义一个“TCAS动作转移表”来限制前后动作合法性 # =================================================== allowed_tcas_transitions = { # 无RA → 可转任意 'COC': [ 'COC', 'DNC2000', 'DND2000', 'DNC1000', 'DND1000', 'DNC500', 'DND500', 'DNC', 'DND', 'MCL', 'MDES', 'CL1500', 'DES1500', 'SCL1500', 'SDES1500', 'SCL2500', 'SDES2500' ], # 预防性RA → 可以保持自己或升级到纠正性(同向), 不建议直接反向 'DNC': ['DNC', 'DNC1000', 'DNC2000', 'CL1500', 'SCL1500', 'SCL2500', 'COC'], 'DND': ['DND', 'DND1000', 'DND2000', 'DES1500', 'SDES1500', 'SDES2500', 'COC'], 'DNC2000': ['DNC2000', 'DNC1000', 'DNC', 'CL1500', 'SCL1500', 'SCL2500', 'COC'], 'DND2000': ['DND2000', 'DND1000', 'DND', 'DES1500', 'SDES1500', 'SDES2500', 'COC'], 'DNC1000': ['DNC1000', 'DNC2000', 'DNC', 'CL1500', 'SCL1500', 'SCL2500', 'COC'], 'DND1000': ['DND1000', 'DND2000', 'DND', 'DES1500', 'SDES1500', 'SDES2500', 'COC'], 
'DNC500': ['DNC500', 'CL1500', 'SCL1500', 'SCL2500', 'COC'], 'DND500': ['DND500', 'DES1500', 'SDES1500', 'SDES2500', 'COC'], 'MCL': ['MCL', 'CL1500', 'SCL1500', 'SCL2500', 'COC'], 'MDES': ['MDES', 'DES1500', 'SDES1500', 'SDES2500', 'COC'], # 纠正性RA(level=2) → 可以保持或加强(到level=3),或回到预防性/COC 'CL1500': ['CL1500', 'SCL1500', 'SCL2500', 'COC', 'MCL', 'DNC', 'DNC2000', 'DNC1000'], 'DES1500': ['DES1500', 'SDES1500', 'SDES2500', 'COC', 'MDES', 'DND', 'DND2000', 'DND1000'], # 更强纠正RA(level=3) → 保持自己/可弱化回到level=2或更低 'SCL1500': ['SCL1500', 'SCL2500', 'CL1500', 'MCL', 'COC'], 'SDES1500': ['SDES1500', 'SDES2500', 'DES1500', 'MDES', 'COC'], 'SCL2500': ['SCL2500', 'SCL1500', 'CL1500', 'MCL', 'COC'], 'SDES2500': ['SDES2500', 'SDES1500', 'DES1500', 'MDES', 'COC'], } class Config: """配置类""" def __init__(self): # 时间步长 self.time_step = 1.0 # 本机性能限制 self.max_climb_rate = 6000.0 / 60.0 self.max_descent_rate = -6000.0 / 60.0 self.max_accel_up = 0.4 * G self.max_accel_down = -0.3 * G # 冲突/软惩罚阈值 self.conflict_threshold_h = CONFLICT_THRESHOLD_H self.conflict_threshold_v = CONFLICT_THRESHOLD_V # 惩罚/奖励系数 self.conflict_penalty = 1.0 self.ttc_penalty_factor = 0.5 self.distance_weight = 0.0001 # =========== 新增: TCAS2相关惩罚权重 =========== # 非法转移大额惩罚 self.illegal_transition_penalty = 5.0 # 反向(RA Reversal)惩罚 self.sense_reversal_penalty = 0.01 # 加强/减弱惩罚 self.strengthen_penalty = 0.005 self.weaken_penalty = 0.001 # DQN相关 self.target_update_method = 'soft' # or 'hard' self.tau = 0.001 # NMAC阈值 self.nmac_threshold_v = 100 self.nmac_threshold_h = 500 # 飞机数量 self.min_intruders = 90 self.max_intruders = 100 # 网络参数 self.hidden_dim = 128 self.num_heads = 4 # 训练参数 self.batch_size = 32 self.gamma = 0.99 self.learning_rate = 1e-4 self.memory_size = 10000 self.target_update_freq = 10 self.epsilon_start = 1.0 self.epsilon_end = 0.05 self.epsilon_decay = 0.995 # 归一化参数 self.normalization = { 'x_range': 5 * NMI2FEET, 'y_range': 3 * NMI2FEET, 'rel_h': 3000.0, 'v_speed': 100.0, 'horizontal_speed': 600 * KT2FPS, 'angle': 2 * np.pi, 'v_speed_o': 100.0 } # ============= 以下为环境/奖励改进的新超参 ============= self.heading_change_prob = 0.05 # 每个时间步入侵机有一定概率改变转向速率 self.turn_rate_range = (-0.03, 0.03) # 转向速率范围(弧度/秒),可自行调整 self.speed_change_prob = 0.1 # 每个时间步入侵机有一定概率改变水平加速度 self.max_horizontal_acc = 5.0 # 入侵机水平加速度(ft/s^2)上下限,可自行调整 self.min_hspeed = 100 * KT2FPS # 入侵机水平速度最小值 self.max_hspeed = 600 * KT2FPS # 入侵机水平速度最大值 self.survival_reward = 0.01 # 每步存活奖励(若无NMAC) self.large_vs_penalty_factor = 0.0001 # 对本机垂直速度的惩罚系数 def compute_tau_horizontal(intruder: dict, own_hspeed: float) -> float: """ 计算水平上的Tau(到最近点的时间)。若不可能产生水平CPA则返回 math.inf。 """ rel_x = intruder['x_range'] rel_y = intruder['y_range'] dist_h = math.sqrt(rel_x ** 2 + rel_y ** 2) if dist_h <= 1e-5: return math.inf vx_int = intruder['horizontal_speed'] * math.cos(intruder['angle']) vy_int = intruder['horizontal_speed'] * math.sin(intruder['angle']) rel_vx = vx_int - own_hspeed rel_vy = vy_int closing_speed = -((rel_x * rel_vx + rel_y * rel_vy) / dist_h) if closing_speed <= 1e-5: return math.inf tau_h = dist_h / closing_speed return tau_h class MultiAircraftEnv: """多飞机环境 + TCAS2动作逻辑 + 奖励改进""" def __init__(self, config: Config): self.config = config self.escape_threshold = 6 * NMI2FEET self.next_intruder_id = 1 self.initial_altitude = 15000.0 # 初始高度 self.action_change_count = 0 # 统计动作变动次数 # 新增:设置最大仿真步数,防止死循环训练 # 可根据需求自己调整 self.max_steps_per_episode = 300 self.current_step = 0 def initialize_intruder(self) -> Dict: x_range = np.random.uniform(3, 5) * np.random.choice([-1, 1]) * NMI2FEET y_range = np.random.uniform(0, 
3) * np.random.choice([-1, 1]) * NMI2FEET # 生成随机角度 angle = np.random.uniform(0, 2 * np.pi) rel_h = np.random.uniform(-3000, 3000) v_speed = np.random.uniform(-100, 100) horizontal_speed = np.random.uniform(200, 600) * KT2FPS # 新增:初始转向速率 turn_rate = np.random.uniform(*self.config.turn_rate_range) intruder_state = { 'x_range': x_range, 'y_range': y_range, 'rel_h': rel_h, 'v_speed': v_speed, 'horizontal_speed': horizontal_speed, 'angle': angle, 'turn_rate': turn_rate, 'id': self.next_intruder_id, 'escaped': False, 'horizontal_range': math.sqrt(x_range ** 2 + y_range ** 2) } self.next_intruder_id += 1 return intruder_state def reset(self) -> Tuple[torch.Tensor, torch.Tensor]: self.action_change_count = 0 self.current_step = 0 num_intruders = random.randint(self.config.min_intruders, self.config.max_intruders) self.own_state = { 'altitude': self.initial_altitude, 'v_speed_o': 0.0, 'prev_action': 0 } self.intruder_states = [] for _ in range(num_intruders): intruder = self.initialize_intruder() self.intruder_states.append(intruder) return self.get_state_tensors() def get_state_tensors(self) -> Tuple[torch.Tensor, torch.Tensor]: own_tensor = torch.tensor([ self.own_state['v_speed_o'] / self.config.normalization['v_speed_o'], self.own_state['prev_action'] ], dtype=torch.float32) active_intruders = [i for i in self.intruder_states if not i['escaped']] if not active_intruders: intruder_tensor = torch.zeros((0, 6)) else: intruder_tensor = torch.tensor([ [ it['rel_h'] / self.config.normalization['rel_h'], it['v_speed'] / self.config.normalization['v_speed'], it['x_range'] / self.config.normalization['x_range'], it['y_range'] / self.config.normalization['y_range'], it['horizontal_speed'] / self.config.normalization['horizontal_speed'], it['angle'] / self.config.normalization['angle'] ] for it in active_intruders ], dtype=torch.float32) return own_tensor, intruder_tensor def clamp_own_vertical_speed(self, v_speed: float) -> float: if v_speed > self.config.max_climb_rate: return self.config.max_climb_rate if v_speed < self.config.max_descent_rate: return self.config.max_descent_rate return v_speed def compute_ttc(self, intruder) -> float: vx = intruder['horizontal_speed'] * math.cos(intruder['angle']) vy = intruder['horizontal_speed'] * math.sin(intruder['angle']) dist = intruder['horizontal_range'] rel_vx = vx - DEFAULT_OWN_AIR_SPEED rel_vy = vy if dist <= 0: return 9999.0 closing_speed = -((intruder['x_range'] * rel_vx + intruder['y_range'] * rel_vy) / dist) if closing_speed <= 0: return 9999.0 return dist / closing_speed def get_vertical_rate(self, action: int, is_own: bool, acceleration: float, current_rate: float) -> float: ra_name = action_number_map[action] rate_limit = vertical_rate_limit[ra_name].copy() if is_own: if ra_name == 'MCL': rate_limit[0] = current_rate if ra_name == 'MDES': rate_limit[1] = current_rate next_rate = current_rate + acceleration * self.config.time_step next_rate = self.clamp_own_vertical_speed(next_rate) else: next_rate = current_rate + acceleration * self.config.time_step next_rate = np.clip(next_rate, -100, 100) return next_rate def step(self, action: int) -> Tuple[Tuple[torch.Tensor, torch.Tensor], float, bool]: self.current_step += 1 prev_action = self.own_state['prev_action'] prev_str = action_number_map[prev_action] current_str = action_number_map[action] # 统计action变动 total_reward = 0.0 if action != prev_action: self.action_change_count += 1 # 1) 检查TCAS转移合法性 if current_str not in allowed_tcas_transitions[prev_str]: total_reward -= 
self.config.illegal_transition_penalty # 2) 方向/等级 判断 dir_prev = action_info[prev_str]['direction'] dir_now = action_info[current_str]['direction'] lvl_prev = action_info[prev_str]['level'] lvl_now = action_info[current_str]['level'] # sense reversal (爬升变下降或下降变爬升) if dir_prev != 0 and dir_now != 0 and (dir_prev * dir_now < 0): total_reward -= self.config.sense_reversal_penalty # strengthen / weaken if dir_prev == dir_now and dir_now != 0: if lvl_now > lvl_prev: total_reward -= self.config.strengthen_penalty elif lvl_now < lvl_prev: total_reward -= self.config.weaken_penalty # 3) 更新本机垂直速度与高度 accel = ACCELS[action] * G prev_v_speed_own = self.own_state['v_speed_o'] self.own_state['v_speed_o'] = self.get_vertical_rate(action, True, accel, prev_v_speed_own) avg_v_speed = (prev_v_speed_own + self.own_state['v_speed_o']) / 2.0 self.own_state['altitude'] += avg_v_speed * self.config.time_step self.own_state['prev_action'] = action # ------------------------------- # 新增:对“大垂直速度”惩罚,避免极端机动 # ------------------------------- total_reward -= self.config.large_vs_penalty_factor * abs(self.own_state['v_speed_o']) nmac_occurred = False for intruder in self.intruder_states: if intruder['escaped']: continue # 入侵机随机更改转向速率 if random.random() < self.config.heading_change_prob: intruder['turn_rate'] = np.random.uniform(*self.config.turn_rate_range) # 入侵机更新水平速度(模拟加减速) if random.random() < self.config.speed_change_prob: horizontal_acc = np.random.uniform(-self.config.max_horizontal_acc, self.config.max_horizontal_acc) # 将此加速度作用一小段时间(1秒),更新 intruder['horizontal_speed'] intruder['horizontal_speed'] += horizontal_acc * self.config.time_step # 限制速度范围 intruder['horizontal_speed'] = np.clip( intruder['horizontal_speed'], self.config.min_hspeed, self.config.max_hspeed ) # 随着每步推进,intruder angle 根据turn_rate发生细微变化 intruder['angle'] += intruder['turn_rate'] * self.config.time_step # 更新垂直速度 v_accel = np.random.uniform(self.config.max_accel_down, self.config.max_accel_up) prev_v_speed_int = intruder['v_speed'] intruder['v_speed'] = self.get_vertical_rate(0, False, v_accel, prev_v_speed_int) old_range = intruder['horizontal_range'] vx = intruder['horizontal_speed'] * math.cos(intruder['angle']) vy = intruder['horizontal_speed'] * math.sin(intruder['angle']) # 更新相对位置 intruder['x_range'] += (vx * self.config.time_step - DEFAULT_OWN_AIR_SPEED * self.config.time_step) intruder['y_range'] += vy * self.config.time_step own_avg_v = (prev_v_speed_own + self.own_state['v_speed_o']) / 2.0 int_avg_v = (prev_v_speed_int + intruder['v_speed']) / 2.0 intruder['rel_h'] += (int_avg_v - own_avg_v) * self.config.time_step intruder['horizontal_range'] = math.sqrt(intruder['x_range'] ** 2 + intruder['y_range'] ** 2) # 计算水平tau tau_h = compute_tau_horizontal(intruder, own_hspeed=DEFAULT_OWN_AIR_SPEED) intruder['tau_h'] = tau_h # NMAC检测 if abs(intruder['rel_h']) < self.config.nmac_threshold_v and intruder['horizontal_range'] < self.config.nmac_threshold_h: nmac_occurred = True total_reward -= 10.0 # 软惩罚:距离/TTC if abs(intruder['rel_h']) < self.config.conflict_threshold_v and intruder['horizontal_range'] < self.config.conflict_threshold_h: dist_penalty = 0.0 if intruder['horizontal_range'] > 0: dist_penalty += self.config.distance_weight / intruder['horizontal_range'] if abs(intruder['rel_h']) > 0: dist_penalty += self.config.distance_weight / abs(intruder['rel_h']) total_reward -= dist_penalty ttc = self.compute_ttc(intruder) if ttc < 9999.0: total_reward -= self.config.ttc_penalty_factor / (ttc + 1e-5) # 逃脱判断 if intruder['horizontal_range'] > 
self.escape_threshold and intruder['horizontal_range'] > old_range or math.isinf(intruder['tau_h']): intruder['escaped'] = True total_reward += 0.1 # 每步若未发生NMAC,给一点生存奖励 if not nmac_occurred: total_reward += self.config.survival_reward # 回合结束条件 active_intruders = [i for i in self.intruder_states if not i['escaped']] done = nmac_occurred or (len(active_intruders) == 0) # 也可设置若到达最大步数则结束 if self.current_step >= self.max_steps_per_episode: done = True return self.get_state_tensors(), total_reward, done def get_statistics(self) -> Dict: active_count = len([i for i in self.intruder_states if not i['escaped']]) escaped_count = len(self.intruder_states) - active_count return { 'active_intruders': active_count, 'escaped_intruders': escaped_count, 'total_intruders': len(self.intruder_states), 'action_changes': self.action_change_count } def record_episode(env, episode_data, filename): """记录一个episode的状态数据到文件""" with open(filename, 'w') as f: # 写入初始配置 config = { "initial_altitude": env.initial_altitude, "time_step": env.config.time_step, "own_speed": DEFAULT_OWN_AIR_SPEED * FPS2KT, # 节 "total_steps": len(episode_data) } f.write(f"#CONFIG\n{str(config)}\n") f.write("#DATA\n") # 写入每个时间步的状态 for state in episode_data: f.write(f"{str(state)}\n") class CASNetwork(nn.Module): """使用注意力机制处理可变数量入侵机的防撞系统网络""" def __init__(self, config: Config): super().__init__() self.config = config # 本机状态编码器 self.own_encoder = nn.Sequential( nn.Linear(2, config.hidden_dim), nn.ReLU() ) # 入侵机状态编码器 self.intruder_encoder = nn.Sequential( nn.Linear(6, config.hidden_dim), nn.ReLU() ) # 注意力层 self.attention = nn.MultiheadAttention( config.hidden_dim, config.num_heads, batch_first=True ) # 输出层 self.output_layer = nn.Sequential( nn.Linear(config.hidden_dim * 2, config.hidden_dim), nn.ReLU(), nn.Linear(config.hidden_dim, len(ACTIONS)) ) def forward(self, own_state, intruder_states): own_encoded = self.own_encoder(own_state) if intruder_states.size(1) == 0: intruder_feature = torch.zeros_like(own_encoded) else: bsz, num_intr, _ = intruder_states.size() intr_enc = self.intruder_encoder( intruder_states.view(-1, 6) ).view(bsz, num_intr, -1) attn_output, _ = self.attention( own_encoded.unsqueeze(1), intr_enc, intr_enc ) intruder_feature = attn_output.squeeze(1) combined = torch.cat([own_encoded, intruder_feature], dim=1) return self.output_layer(combined) class DQNAgent: """Double DQN智能体""" def __init__(self, config: Config): self.config = config self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.policy_net = CASNetwork(config).to(self.device) self.target_net = CASNetwork(config).to(self.device) self.target_net.load_state_dict(self.policy_net.state_dict()) self.target_net.eval() self.optimizer = optim.Adam(self.policy_net.parameters(), lr=config.learning_rate) self.memory = deque(maxlen=config.memory_size) self.epsilon = config.epsilon_start def pad_intruder_states(self, intruder_states: torch.Tensor, max_intruders: int) -> torch.Tensor: batch_size = 1 if len(intruder_states.shape) == 2 else intruder_states.shape[0] current_intruders = intruder_states.shape[-2] if current_intruders < max_intruders: padding = torch.zeros( (batch_size, max_intruders - current_intruders, 6), device=intruder_states.device ) return torch.cat([intruder_states, padding], dim=-2) return intruder_states[:, :max_intruders, :] def select_action(self, own_state: torch.Tensor, intruder_states: torch.Tensor) -> int: if random.random() < self.epsilon: return random.randrange(len(ACTIONS)) with torch.no_grad(): own_state = 
own_state.unsqueeze(0).to(self.device) intruder_states = intruder_states.unsqueeze(0).to(self.device) intruder_states = self.pad_intruder_states(intruder_states, self.config.max_intruders) q_values = self.policy_net(own_state, intruder_states) return q_values.argmax(dim=1).item() def store_transition(self, s_own, s_intr, action, reward, s2_own, s2_intr, done): self.memory.append((s_own, s_intr, action, reward, s2_own, s2_intr, done)) def train_step(self): if len(self.memory) < self.config.batch_size: return None batch = random.sample(self.memory, self.config.batch_size) s_own, s_intr, acts, rws, s2_own, s2_intr, dns = zip(*batch) s_own = torch.stack(s_own).to(self.device) acts = torch.tensor(acts, dtype=torch.long).to(self.device) rws = torch.tensor(rws, dtype=torch.float).to(self.device) s2_own = torch.stack(s2_own).to(self.device) dns = torch.tensor(dns, dtype=torch.float).to(self.device) max_intruders = max(x.size(0) for x in s_intr) max_intruders = min(max_intruders, self.config.max_intruders) intr_batch = [] for s_i in s_intr: padded = self.pad_intruder_states(s_i.unsqueeze(0), max_intruders).squeeze(0) intr_batch.append(padded) intr_batch = torch.stack(intr_batch).to(self.device) intr2_batch = [] for s2_i in s2_intr: padded2 = self.pad_intruder_states(s2_i.unsqueeze(0), max_intruders).squeeze(0) intr2_batch.append(padded2) intr2_batch = torch.stack(intr2_batch).to(self.device) q_curr = self.policy_net(s_own, intr_batch) q_curr_a = q_curr.gather(1, acts.unsqueeze(1)) with torch.no_grad(): next_actions = self.policy_net(s2_own, intr2_batch).argmax(1) q_next = self.target_net(s2_own, intr2_batch) q_next_a = q_next.gather(1, next_actions.unsqueeze(1)).squeeze(1) target_q = rws + (1 - dns) * self.config.gamma * q_next_a loss_fn = nn.SmoothL1Loss() loss = loss_fn(q_curr_a.squeeze(), target_q) self.optimizer.zero_grad() loss.backward() nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0) self.optimizer.step() self.epsilon = max(self.config.epsilon_end, self.epsilon * self.config.epsilon_decay) return loss.item() def update_target_network(self): self.target_net.load_state_dict(self.policy_net.state_dict()) def soft_update_target_network(self, tau=0.001): for tp, pp in zip(self.target_net.parameters(), self.policy_net.parameters()): tp.data.copy_(tau * pp.data + (1.0 - tau) * tp.data) def train_cas(config: Config, num_episodes: int = 5000): env = MultiAircraftEnv(config) agent = DQNAgent(config) current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S') log_dir = os.path.join('runs', f'CAS_training_{current_time}') writer = SummaryWriter(log_dir) episode_rewards = [] nmac_counts = [] for episode in range(num_episodes): own_state, intr_states = env.reset() ep_reward = 0.0 steps = 0 nmac_count = 0 episode_loss = 0.0 num_training_steps = 0 initial_intruders = len(env.intruder_states) # 记录本回合所有状态(可选) episode_data = [] while True: current_state = { "own_state": { "altitude": env.own_state['altitude'], "v_speed": env.own_state['v_speed_o'], "action": action_number_map[env.own_state['prev_action']] }, "intruders": [{ "id": intr['id'], "x_range": intr['x_range'], "y_range": intr['y_range'], "rel_h": intr['rel_h'], "v_speed": intr['v_speed'], "horizontal_speed": intr['horizontal_speed'], "angle": intr['angle'], "escaped": intr['escaped'], "turn_rate": intr['turn_rate'] } for intr in env.intruder_states] } episode_data.append(current_state) action = agent.select_action(own_state, intr_states) (next_own, next_intr), reward, done = env.step(action) agent.store_transition(own_state, 
intr_states, action, reward, next_own, next_intr, done) if len(agent.memory) >= config.batch_size: loss = agent.train_step() if loss is not None: episode_loss += loss num_training_steps += 1 ep_reward += reward steps += 1 if reward <= -100.0: nmac_count += 1 own_state, intr_states = next_own, next_intr # 软更新 target 网络 agent.soft_update_target_network(tau=config.tau) if done: # 可选:将episode数据保存到文件 filename = f'episode_data_{episode}.txt' record_episode(env, episode_data, filename) break episode_rewards.append(ep_reward) nmac_counts.append(nmac_count) # 日志记录 writer.add_scalar('1.Rewards/Episode_Reward', ep_reward, episode) if len(episode_rewards) > 0: avg_r = sum(episode_rewards[-100:]) / min(len(episode_rewards), 100) writer.add_scalar('1.Rewards/Average_Reward', avg_r, episode) writer.add_scalar('2.Safety/NMAC_Count', nmac_count, episode) if len(nmac_counts) > 0: avg_n = sum(nmac_counts[-100:]) / min(len(nmac_counts), 100) writer.add_scalar('2.Safety/Average_NMAC', avg_n, episode) writer.add_scalar('3.Training/Steps', steps, episode) writer.add_scalar('3.Training/Epsilon', agent.epsilon, episode) if num_training_steps > 0: writer.add_scalar('3.Training/Average_Loss', episode_loss / num_training_steps, episode) stats = env.get_statistics() writer.add_scalar('6.Intruders/Initial_Count', initial_intruders, episode) writer.add_scalar('6.Intruders/Active_Count', stats['active_intruders'], episode) writer.add_scalar('6.Intruders/Escaped_Count', stats['escaped_intruders'], episode) writer.add_scalar('7.Own_State/Vertical_Speed', env.own_state['v_speed_o'], episode) writer.add_scalar('7.Own_State/Action_Changes', stats['action_changes'], episode) print(f"\nEpisode {episode + 1}/{num_episodes}") print(f"Steps: {steps}, Reward: {ep_reward:.2f}, NMAC: {nmac_count}, Eps: {agent.epsilon:.3f}") print(f"Initial Intruders: {initial_intruders}") if len(episode_rewards) > 0: avg_r = sum(episode_rewards[-100:]) / min(len(episode_rewards), 100) print(f"AvgReward(100): {avg_r:.2f}") if len(nmac_counts) > 0: avg_n = sum(nmac_counts[-100:]) / min(len(nmac_counts), 100) print(f"AvgNMAC(100): {avg_n:.3f}") if num_training_steps > 0: print(f"AvgLoss: {episode_loss / num_training_steps:.6f}") print(f"Active Intruders: {stats['active_intruders']}, Escaped Intruders: {stats['escaped_intruders']}") print(f"Action Changes: {stats['action_changes']}") print("-" * 50) writer.close() return agent, episode_rewards, nmac_counts def evaluate_cas(agent: DQNAgent, config: Config, num_episodes: int = 100): env = MultiAircraftEnv(config) current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S') log_dir = os.path.join('runs', f'CAS_evaluation_{current_time}') writer = SummaryWriter(log_dir) total_rewards = [] nmac_counts = [] step_counts = [] old_eps = agent.epsilon agent.epsilon = 0.0 for ep in range(num_episodes): own_state, intr_states = env.reset() ep_reward = 0.0 steps = 0 nmac_count = 0 action_counts = torch.zeros(len(ACTIONS)) while True: action = agent.select_action(own_state, intr_states) action_counts[action] += 1 (next_own, next_intr), reward, done = env.step(action) ep_reward += reward steps += 1 if reward <= -100.0: nmac_count += 1 own_state, intr_states = next_own, next_intr if done: break writer.add_scalar('Evaluation/Episode_Reward', ep_reward, ep) writer.add_scalar('Evaluation/NMAC_Count', nmac_count, ep) writer.add_scalar('Evaluation/Steps', steps, ep) action_probs = action_counts / steps for name, idx in ACTIONS.items(): writer.add_scalar(f'Evaluation/Action_{name}', action_probs[idx].item(), ep) if 
hasattr(env, 'get_statistics'): st = env.get_statistics() writer.add_scalar('Evaluation/Active_Intruders', st['active_intruders'], ep) writer.add_scalar('Evaluation/Escaped_Intruders', st['escaped_intruders'], ep) total_rewards.append(ep_reward) nmac_counts.append(nmac_count) step_counts.append(steps) print(f"\nEvaluation Episode {ep + 1}") print(f"Steps: {steps}, Reward: {ep_reward:.2f}, NMAC: {nmac_count}") avg_r = sum(total_rewards) / len(total_rewards) avg_n = sum(nmac_counts) / len(nmac_counts) avg_stp = sum(step_counts) / len(step_counts) writer.add_hparams( {'eval_episodes': num_episodes}, { 'hparam/avg_reward': avg_r, 'hparam/avg_nmac': avg_n, 'hparam/avg_steps': avg_stp } ) writer.close() agent.epsilon = old_eps print("\nEvaluation Results:") print(f"Avg Reward: {avg_r:.2f}") print(f"Avg NMAC: {avg_n:.3f}") print(f"Avg Steps: {avg_stp:.1f}") return avg_r, avg_n, avg_stp if __name__ == "__main__": config = Config() print("Starting training...") agent, ep_rewards, nmacs = train_cas(config, num_episodes=5000) # 保存模型 save_dir = 'models' os.makedirs(save_dir, exist_ok=True) current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S') model_path = os.path.join(save_dir, f'cas_model_{current_time}.pth') torch.save({ 'policy_net_state_dict': agent.policy_net.state_dict(), 'target_net_state_dict': agent.target_net.state_dict(), 'optimizer_state_dict': agent.optimizer.state_dict(), 'config': config, }, model_path) # 评估 print("\nStarting evaluation...") evaluate_cas(agent, config, num_episodes=1000)
Environment complexity improvements:
MultiAircraftEnv.step() now adds random-change logic for each intruder's turn_rate and horizontal_speed:
if random.random() < self.config.heading_change_prob:
    intruder['turn_rate'] = np.random.uniform(*self.config.turn_rate_range)
if random.random() < self.config.speed_change_prob:
    horizontal_acc = np.random.uniform(-self.config.max_horizontal_acc,
                                       self.config.max_horizontal_acc)
    intruder['horizontal_speed'] += horizontal_acc * self.config.time_step
    intruder['horizontal_speed'] = np.clip(
        intruder['horizontal_speed'],
        self.config.min_hspeed,
        self.config.max_hspeed
    )
intruder['angle'] += intruder['turn_rate'] * self.config.time_step
Reward function improvements:
Every step without an NMAC adds a small positive survival reward self.config.survival_reward (e.g. 0.01), encouraging the agent to stay conflict-free for as long as possible. In addition, total_reward -= self.config.large_vs_penalty_factor * abs(self.own_state['v_speed_o']) penalizes large own-ship vertical speeds, which limits extreme climbs/descents to some extent; this term can be tuned to the needs of a specific experiment.

Other notes:
A per-episode step limit max_steps_per_episode (300 in the example) forces done once it is reached.
Please improve the model with respect to the following two points, and give the complete improved code together with the corresponding explanation:
1. Introduce a realistic dynamics model so that climbing/descending is coupled to the own ship's horizontal speed, thrust limits, or climb-rate limits, bringing it closer to real flight behaviour.
2. Extend to continuous multi-episode traffic: aircraft can enter or leave the scene at random instead of being generated only once at the initial time.
Below is an illustrative complete version of the code that adds or modifies the following features on top of the original:

Simplified own-ship flight dynamics.

Continuous traffic throughout the episode.

Note: the code below only demonstrates the approach; in a real application, more rigorous thrust/drag/lift formulas can be added as needed, and the various hyperparameters (such as spawn_prob and speed_coupling_factor) should be tuned.
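A minimal sketch of the per-step spawning idea described above (the helper name is illustrative; the full code below implements the same idea inside MultiAircraftEnv):

import random

def maybe_spawn(env, spawn_prob: float = 0.05) -> None:
    """With probability spawn_prob, add one new intruder to the running episode."""
    if random.random() < spawn_prob:
        env.intruder_states.append(env.initialize_intruder())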
```python
import math
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
from typing import List, Tuple, Dict
from torch.utils.tensorboard import SummaryWriter
import datetime
import os

# ======================
# Constants
# ======================
NMI2FEET = 6076.11549
KT2FPS = 1.6878098571
G = 9.80665  # gravitational acceleration (m/s^2; the rest of the model works in ft and ft/s)
FPS2KT = 1 / KT2FPS
FPM2FPS = 1 / 60

# A looser conflict-detection threshold can be used to apply a soft penalty
CONFLICT_THRESHOLD_H = 2000.0  # horizontal separation below 2000 ft is treated as a developing conflict
CONFLICT_THRESHOLD_V = 600.0   # vertical separation below 600 ft is treated as a developing conflict

DEFAULT_OWN_AIR_SPEED = 200 * KT2FPS  # initial own-aircraft horizontal speed (200 kt)

# ---------------------
# Action space
# ---------------------
ACTIONS = {
    'COC': 0,  # Clear of Conflict
    'DNC2000': 1,
    'DND2000': 2,
    'DNC1000': 3,
    'DND1000': 4,
    'DNC500': 5,
    'DND500': 6,
    'DNC': 7,
    'DND': 8,
    'MCL': 9,
    'MDES': 10,
    'CL1500': 11,
    'DES1500': 12,
    'SCL1500': 13,
    'SDES1500': 14,
    'SCL2500': 15,
    'SDES2500': 16
}

action_number_map = [
    'COC', 'DNC2000', 'DND2000', 'DNC1000', 'DND1000', 'DNC500', 'DND500', 'DNC', 'DND',
    'MCL', 'MDES', 'CL1500', 'DES1500', 'SCL1500', 'SDES1500', 'SCL2500', 'SDES2500'
]

# ---------------------
# Vertical-rate acceleration table per action (following the original logic)
# ---------------------
ACCELS = [0] + [1 / 4] * 12 + [1 / 3] * 4

# ---------------------
# Vertical-rate limits
# ---------------------
vertical_rate_limit = {
    'COC': [-100.0, 100.0],
    'DNC2000': [-100.0, 2000.0 / 60],
    'DND2000': [-2000.0 / 60, 100.0],
    'DNC1000': [-100.0, 1000.0 / 60],
    'DND1000': [-1000.0 / 60, 100.0],
    'DNC500': [-100.0, 500.0 / 60],
    'DND500': [-500.0 / 60, 100.0],
    'DNC': [-100.0, 0.0],
    'DND': [0.0, 100.0],
    'MCL': [0.0, 100.0],
    'MDES': [-100.0, 0.0],
    'CL1500': [1500.0 / 60, 100.0],
    'DES1500': [-100.0, -1500.0 / 60],
    'SCL1500': [1500.0 / 60, 100.0],
    'SDES1500': [-100.0, -1500.0 / 60],
    'SCL2500': [2500.0 / 60, 100.0],
    'SDES2500': [-100.0, -2500.0 / 60]
}

# ===================================================
# TCAS action information
# ===================================================
action_info = {
    'COC': {'direction': 0, 'level': 0},
    # Preventive RAs
    'DNC2000': {'direction': 0, 'level': 1},
    'DND2000': {'direction': 0, 'level': 1},
    'DNC1000': {'direction': 0, 'level': 1},
    'DND1000': {'direction': 0, 'level': 1},
    'DNC500': {'direction': 0, 'level': 1},
    'DND500': {'direction': 0, 'level': 1},
    'DNC': {'direction': 0, 'level': 1},
    'DND': {'direction': 0, 'level': 1},
    'MCL': {'direction': +1, 'level': 1},
    'MDES': {'direction': -1, 'level': 1},
    # Corrective RAs
    'CL1500': {'direction': +1, 'level': 2},
    'DES1500': {'direction': -1, 'level': 2},
    # Strengthened corrective RAs
    'SCL1500': {'direction': +1, 'level': 3},
    'SDES1500': {'direction': -1, 'level': 3},
    'SCL2500': {'direction': +1, 'level': 3},
    'SDES2500': {'direction': -1, 'level': 3},
}

# ===================================================
# Allowed action transitions
# ===================================================
allowed_tcas_transitions = {
    'COC': [
        'COC', 'DNC2000', 'DND2000', 'DNC1000', 'DND1000', 'DNC500', 'DND500', 'DNC', 'DND', 'MCL', 'MDES',
        'CL1500', 'DES1500', 'SCL1500', 'SDES1500', 'SCL2500', 'SDES2500'
    ],
    'DNC': ['DNC', 'DNC1000', 'DNC2000', 'CL1500', 'SCL1500', 'SCL2500', 'COC'],
    'DND': ['DND', 'DND1000', 'DND2000', 'DES1500', 'SDES1500', 'SDES2500', 'COC'],
    'DNC2000': ['DNC2000', 'DNC1000', 'DNC', 'CL1500', 'SCL1500', 'SCL2500', 'COC'],
    'DND2000': ['DND2000', 'DND1000', 'DND', 'DES1500', 'SDES1500', 'SDES2500', 'COC'],
    'DNC1000': ['DNC1000', 'DNC2000', 'DNC', 'CL1500', 'SCL1500', 'SCL2500', 'COC'],
    'DND1000': ['DND1000', 'DND2000', 'DND', 'DES1500', 'SDES1500', 'SDES2500', 'COC'],
    'DNC500': ['DNC500', 'CL1500', 'SCL1500', 'SCL2500', 'COC'],
    'DND500': ['DND500', 'DES1500', 'SDES1500', 'SDES2500', 'COC'],
    'MCL': ['MCL', 'CL1500', 'SCL1500', 'SCL2500', 'COC'],
    'MDES': ['MDES', 'DES1500', 'SDES1500', 'SDES2500', 'COC'],
    'CL1500': ['CL1500', 'SCL1500', 'SCL2500', 'COC', 'MCL', 'DNC', 'DNC2000', 'DNC1000'],
    'DES1500': ['DES1500', 'SDES1500', 'SDES2500', 'COC', 'MDES', 'DND', 'DND2000', 'DND1000'],
    'SCL1500': ['SCL1500', 'SCL2500', 'CL1500', 'MCL', 'COC'],
    'SDES1500': ['SDES1500', 'SDES2500', 'DES1500', 'MDES', 'COC'],
    'SCL2500': ['SCL2500', 'SCL1500', 'CL1500', 'MCL', 'COC'],
    'SDES2500': ['SDES2500', 'SDES1500', 'DES1500', 'MDES', 'COC'],
}


class Config:
    def __init__(self):
        # Time step
        self.time_step = 1.0
        # Own-aircraft performance limits
        self.max_climb_rate = 6000.0 / 60.0
        self.max_descent_rate = -6000.0 / 60.0
        # Simplified "realistic" dynamics
        self.speed_coupling_factor = 0.001   # coupling factor from climb/descent to horizontal speed
        self.min_own_hspeed = 150 * KT2FPS   # lower bound on own horizontal speed (150 kt)
        self.max_own_hspeed = 350 * KT2FPS   # upper bound on own horizontal speed (350 kt)
        self.drag_factor = 0.0               # simple drag effect (0 or a small value in this example)
        # Acceleration limits
        self.max_accel_up = 0.4 * G
        self.max_accel_down = -0.3 * G
        # Conflict thresholds
        self.conflict_threshold_h = CONFLICT_THRESHOLD_H
        self.conflict_threshold_v = CONFLICT_THRESHOLD_V
        # Reward coefficients
        self.conflict_penalty = 1.0
        self.ttc_penalty_factor = 0.5
        self.distance_weight = 0.0001
        # Penalties for illegal transitions / sense reversals / strengthening / weakening
        self.illegal_transition_penalty = 5.0
        self.sense_reversal_penalty = 0.01
        self.strengthen_penalty = 0.005
        self.weaken_penalty = 0.001
        # DQN settings
        self.target_update_method = 'soft'
        self.tau = 0.001
        # NMAC thresholds
        self.nmac_threshold_v = 100
        self.nmac_threshold_h = 500
        # Number of intruders
        self.min_intruders = 5
        self.max_intruders = 10
        # Continuous traffic across the episode
        self.world_size_x = 8 * NMI2FEET
        self.world_size_y = 8 * NMI2FEET
        self.spawn_prob = 0.05  # 5% chance per step of spawning a new intruder
        # Network parameters
        self.hidden_dim = 128
        self.num_heads = 4
        # Training parameters
        self.batch_size = 32
        self.gamma = 0.99
        self.learning_rate = 1e-4
        self.memory_size = 10000
        self.target_update_freq = 10
        self.epsilon_start = 1.0
        self.epsilon_end = 0.05
        self.epsilon_decay = 0.995
        # Normalization parameters
        self.normalization = {
            'x_range': 5 * NMI2FEET,
            'y_range': 3 * NMI2FEET,
            'rel_h': 3000.0,
            'v_speed': 100.0,
            'horizontal_speed': 600 * KT2FPS,
            'angle': 2 * np.pi,
            'v_speed_o': 100.0
        }
        # Additional rewards
        self.survival_reward = 0.01
        self.large_vs_penalty_factor = 0.0001
        # Episode length limit
        self.max_steps_per_episode = 300


def compute_tau_horizontal(intruder: dict, own_hspeed: float) -> float:
    rel_x = intruder['x_range']
    rel_y = intruder['y_range']
    dist_h = math.sqrt(rel_x ** 2 + rel_y ** 2)
    if dist_h <= 1e-5:
        return math.inf
    vx_int = intruder['horizontal_speed'] * math.cos(intruder['angle'])
    vy_int = intruder['horizontal_speed'] * math.sin(intruder['angle'])
    rel_vx = vx_int - own_hspeed
    rel_vy = vy_int
    closing_speed = -((rel_x * rel_vx + rel_y * rel_vy) / dist_h)
    if closing_speed <= 1e-5:
        return math.inf
    tau_h = dist_h / closing_speed
    return tau_h


class MultiAircraftEnv:
    """Multi-aircraft environment + simplified own-aircraft dynamics + continuous traffic."""

    def __init__(self, config: Config):
        self.config = config
        self.escape_threshold = 6 * NMI2FEET
        self.next_intruder_id = 1
        self.action_change_count = 0
        self.current_step = 0
        self.intruder_states = []
        self.own_state = {}

    def reset(self) -> Tuple[torch.Tensor, torch.Tensor]:
        self.action_change_count = 0
        self.current_step = 0
        self.next_intruder_id = 1
        self.intruder_states = []
        # Initialize the own-aircraft state, now with an extra horizontal-speed dimension
        self.own_state = {
            'altitude': 15000.0,
            'v_speed_o': 0.0,                     # vertical speed
            'h_speed_o': DEFAULT_OWN_AIR_SPEED,   # horizontal speed (ft/s)
            'prev_action': 0
        }
        # Spawn min_intruders .. max_intruders intruders at the start
        num_intruders = random.randint(self.config.min_intruders, self.config.max_intruders)
        for _ in range(num_intruders):
            self.intruder_states.append(self.initialize_intruder())
        return self.get_state_tensors()

    def initialize_intruder(self) -> Dict:
        # Randomly place the intruder within the world bounds
        # (positions some distance out can be used to simulate arriving traffic)
        x_range = np.random.uniform(-self.config.world_size_x, self.config.world_size_x)
        y_range = np.random.uniform(-self.config.world_size_y, self.config.world_size_y)
        # Random heading and speed
        angle = np.random.uniform(0, 2 * np.pi)
        horizontal_speed = np.random.uniform(200, 600) * KT2FPS
        # Make it move "into" the area (simplified: aiming at the own aircraft at (0, 0) also works)
        # The heading could be fully random; for demonstration the intruder closes on (0, 0)
        dx = -x_range
        dy = -y_range
        desired_angle = math.atan2(dy, dx)
        # Add some random perturbation around desired_angle
        angle = desired_angle + np.random.uniform(-0.3, 0.3)
        turn_rate = np.random.uniform(-0.02, 0.02)  # tunable
        rel_h = np.random.uniform(-3000, 3000)
        v_speed = np.random.uniform(-100, 100)
        intruder_state = {
            'x_range': x_range,
            'y_range': y_range,
            'rel_h': rel_h,
            'v_speed': v_speed,
            'horizontal_speed': horizontal_speed,
            'angle': angle,
            'turn_rate': turn_rate,
            'escaped': False,
            'id': self.next_intruder_id
        }
        intruder_state['horizontal_range'] = math.sqrt(x_range ** 2 + y_range ** 2)
        self.next_intruder_id += 1
        return intruder_state

    def get_state_tensors(self) -> Tuple[torch.Tensor, torch.Tensor]:
        # The own state gains an extra h_speed_o (horizontal speed), normalized
        own_tensor = torch.tensor([
            self.own_state['v_speed_o'] / self.config.normalization['v_speed_o'],
            self.own_state['prev_action'],
            self.own_state['h_speed_o'] / self.config.normalization['horizontal_speed']  # new element
        ], dtype=torch.float32)

        active_intruders = [i for i in self.intruder_states if not i['escaped']]
        if not active_intruders:
            intruder_tensor = torch.zeros((0, 6), dtype=torch.float32)
        else:
            intruder_tensor = torch.tensor([
                [
                    it['rel_h'] / self.config.normalization['rel_h'],
                    it['v_speed'] / self.config.normalization['v_speed'],
                    it['x_range'] / self.config.normalization['x_range'],
                    it['y_range'] / self.config.normalization['y_range'],
                    it['horizontal_speed'] / self.config.normalization['horizontal_speed'],
                    (it['angle'] % (2 * np.pi)) / self.config.normalization['angle']
                ]
                for it in active_intruders
            ], dtype=torch.float32)
        return own_tensor, intruder_tensor

    def clamp_own_vertical_speed(self, vs: float) -> float:
        return max(self.config.max_descent_rate, min(vs, self.config.max_climb_rate))

    def clamp_own_horizontal_speed(self, hs: float) -> float:
        return max(self.config.min_own_hspeed, min(hs, self.config.max_own_hspeed))

    def spawn_new_intruder(self):
        """Spawn a new intruder at random."""
        new_intruder = self.initialize_intruder()
        self.intruder_states.append(new_intruder)

    def step(self, action: int) -> Tuple[Tuple[torch.Tensor, torch.Tensor], float, bool]:
        self.current_step += 1
        done = False
        reward = 0.0

        # 1) Legality of the TCAS action transition & extra penalties
        prev_action = self.own_state['prev_action']
        curr_name = action_number_map[action]
        prev_name = action_number_map[prev_action]
        if curr_name not in allowed_tcas_transitions[prev_name]:
            reward -= self.config.illegal_transition_penalty

        dir_prev = action_info[prev_name]['direction']
        dir_now = action_info[curr_name]['direction']
        lvl_prev = action_info[prev_name]['level']
        lvl_now = action_info[curr_name]['level']
        # Sense-reversal penalty
        if dir_prev != 0 and dir_now != 0 and (dir_prev * dir_now < 0):
            reward -= self.config.sense_reversal_penalty
        # Strengthening/weakening penalties
        if dir_prev == dir_now and dir_now != 0:
            if lvl_now > lvl_prev:
                reward -= self.config.strengthen_penalty
            elif lvl_now < lvl_prev:
                reward -= self.config.weaken_penalty
        # Count action changes
        if action != prev_action:
            self.action_change_count += 1

        # 2) Update own-aircraft dynamics
        accel_v = ACCELS[action] * G
        # Clamp the acceleration
        accel_v = max(self.config.max_accel_down, min(accel_v, self.config.max_accel_up))
        prev_v_speed = self.own_state['v_speed_o']
        # New vertical speed
        next_v_speed = prev_v_speed + accel_v * self.config.time_step
        next_v_speed = self.clamp_own_vertical_speed(next_v_speed)

        # Couple climb/descent into the horizontal speed (simplified)
        # Assumption: faster climb -> more horizontal-speed decay; descent -> less decay or a slight increase
        # This could be written more precisely from a thrust model / energy conservation
        prev_h_speed = self.own_state['h_speed_o']
        # An illustrative coupling: delta_h = -speed_coupling_factor * next_v_speed
        # Different coefficients could be used for climb vs. descent
        delta_h = - self.config.speed_coupling_factor * (next_v_speed)
        # Optional drag term
        drag = self.config.drag_factor * (prev_h_speed ** 2)
        next_h_speed = prev_h_speed + delta_h - drag * self.config.time_step
        next_h_speed = self.clamp_own_horizontal_speed(next_h_speed)

        # Update state
        avg_v_speed = 0.5 * (prev_v_speed + next_v_speed)
        self.own_state['altitude'] += avg_v_speed * self.config.time_step
        self.own_state['v_speed_o'] = next_v_speed
        self.own_state['h_speed_o'] = next_h_speed
        self.own_state['prev_action'] = action

        # Extra penalty for large vertical speed
        reward -= self.config.large_vs_penalty_factor * abs(next_v_speed)

        # Update intruder dynamics
        nmac_occurred = False
        for intruder in self.intruder_states:
            if intruder.get('escaped', False):
                continue
            old_range = intruder['horizontal_range']
            # Update its vertical speed
            v_accel = np.random.uniform(self.config.max_accel_down, self.config.max_accel_up)
            prev_vi = intruder['v_speed']
            next_vi = prev_vi + v_accel * self.config.time_step
            # Clamp to [-100, 100] as an example
            next_vi = np.clip(next_vi, -100, 100)
            intruder['v_speed'] = next_vi

            # Horizontal heading & velocity
            intruder['angle'] += intruder['turn_rate'] * self.config.time_step
            vx = intruder['horizontal_speed'] * math.cos(intruder['angle'])
            vy = intruder['horizontal_speed'] * math.sin(intruder['angle'])
            # Own horizontal motion: next_h_speed ft/s (assumed along the positive X axis)
            intruder['x_range'] += (vx - next_h_speed) * self.config.time_step
            intruder['y_range'] += vy * self.config.time_step

            # Update relative altitude
            avg_vi_intr = 0.5 * (prev_vi + next_vi)
            intruder['rel_h'] += (avg_vi_intr - avg_v_speed) * self.config.time_step

            # New horizontal range
            intruder['horizontal_range'] = math.sqrt(
                intruder['x_range'] ** 2 + intruder['y_range'] ** 2
            )

            # NMAC detection
            if abs(intruder['rel_h']) < self.config.nmac_threshold_v and intruder['horizontal_range'] < self.config.nmac_threshold_h:
                nmac_occurred = True
                reward -= 10.0

            # Soft conflict penalty
            if abs(intruder['rel_h']) < self.config.conflict_threshold_v and intruder['horizontal_range'] < self.config.conflict_threshold_h:
                dist_penalty = 0.0
                if intruder['horizontal_range'] > 0:
                    dist_penalty += self.config.distance_weight / intruder['horizontal_range']
                if abs(intruder['rel_h']) > 0:
                    dist_penalty += self.config.distance_weight / abs(intruder['rel_h'])
                reward -= dist_penalty

            # TTC
            ttc = compute_tau_horizontal(intruder, own_hspeed=next_h_speed)
            if ttc < 9999.0:
                reward -= self.config.ttc_penalty_factor / (ttc + 1e-5)

            # Escaped if the horizontal range is growing beyond the threshold, or tau_h = inf
            if (intruder['horizontal_range'] > self.escape_threshold and intruder['horizontal_range'] > old_range) \
                    or intruder['horizontal_range'] > 2 * self.config.world_size_x:  # the world boundary could be used directly
                intruder['escaped'] = True
                reward += 0.1

            # Out of bounds (departed)
            if (abs(intruder['x_range']) > self.config.world_size_x
                    or abs(intruder['y_range']) > self.config.world_size_y):
                intruder['escaped'] = True

        # Small survival reward per step if no NMAC occurred
        if not nmac_occurred:
            reward += self.config.survival_reward

        # Randomly spawn a new intruder
        if random.random() < self.config.spawn_prob:
            self.spawn_new_intruder()

        # Episode termination
        active_intruders = [i for i in self.intruder_states if not i['escaped']]
        if nmac_occurred or len(active_intruders) == 0:
            done = True
        if self.current_step >= self.config.max_steps_per_episode:
            done = True

        return self.get_state_tensors(), reward, done

    def get_statistics(self) -> Dict:
        active_count = len([i for i in self.intruder_states if not i['escaped']])
        escaped_count = len(self.intruder_states) - active_count
        return {
            'active_intruders': active_count,
            'escaped_intruders': escaped_count,
            'total_intruders': len(self.intruder_states),
            'action_changes': self.action_change_count
        }


def record_episode(env, episode_data, filename):
    with open(filename, 'w') as f:
        config_info = {
            "altitude": env.own_state['altitude'],
            "v_speed_o": env.own_state['v_speed_o'],
            "h_speed_o": env.own_state['h_speed_o'],
            "time_step": env.config.time_step
        }
        f.write("# CONFIG\n")
        f.write(str(config_info) + "\n")
        f.write("# DATA\n")
        for s in episode_data:
            f.write(str(s) + "\n")


class CASNetwork(nn.Module):
    """Attention-based network handling a variable number of intruders."""

    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        # Own-state encoder (now a 3-D input with the extra h_speed_o)
        self.own_encoder = nn.Sequential(
            nn.Linear(3, config.hidden_dim),
            nn.ReLU()
        )
        # Intruder-state encoder (still 6-D)
        self.intruder_encoder = nn.Sequential(
            nn.Linear(6, config.hidden_dim),
            nn.ReLU()
        )
        # Attention layer
        self.attention = nn.MultiheadAttention(
            config.hidden_dim, config.num_heads, batch_first=True
        )
        # Output layer
        self.output_layer = nn.Sequential(
            nn.Linear(config.hidden_dim * 2, config.hidden_dim),
            nn.ReLU(),
            nn.Linear(config.hidden_dim, len(ACTIONS))
        )

    def forward(self, own_state, intruder_states):
        own_encoded = self.own_encoder(own_state)  # [B, hidden_dim]
        if intruder_states.size(1) == 0:
            # No intruders: use zeros
            intruder_feature = torch.zeros_like(own_encoded)
        else:
            bsz, num_intr, _ = intruder_states.size()
            intr_enc = self.intruder_encoder(
                intruder_states.view(-1, 6)
            ).view(bsz, num_intr, -1)
            attn_output, _ = self.attention(
                own_encoded.unsqueeze(1), intr_enc, intr_enc
            )
            intruder_feature = attn_output.squeeze(1)
        combined = torch.cat([own_encoded, intruder_feature], dim=1)
        return self.output_layer(combined)


class DQNAgent:
    """Double DQN."""

    def __init__(self, config: Config):
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.policy_net = CASNetwork(config).to(self.device)
        self.target_net = CASNetwork(config).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=config.learning_rate)
        self.memory = deque(maxlen=config.memory_size)
        self.epsilon = config.epsilon_start

    def pad_intruder_states(self, intruder_states: torch.Tensor, max_intruders: int) -> torch.Tensor:
        batch_size = 1 if len(intruder_states.shape) == 2 else intruder_states.shape[0]
        current_intruders = intruder_states.shape[-2]
        if current_intruders < max_intruders:
            padding = torch.zeros(
                (batch_size, max_intruders - current_intruders, 6),
                device=intruder_states.device
            )
            return torch.cat([intruder_states, padding], dim=-2)
        return intruder_states[:, :max_intruders, :]

    def select_action(self, own_state: torch.Tensor, intruder_states: torch.Tensor) -> int:
        # epsilon-greedy
        if random.random() < self.epsilon:
            return random.randrange(len(ACTIONS))
        with torch.no_grad():
            own_state = own_state.unsqueeze(0).to(self.device)
            intruder_states = intruder_states.unsqueeze(0).to(self.device)
            intruder_states = self.pad_intruder_states(intruder_states, self.config.max_intruders)
            q_values = self.policy_net(own_state, intruder_states)
            return q_values.argmax(dim=1).item()

    def store_transition(self, s_own, s_intr, action, reward, s2_own, s2_intr, done):
        self.memory.append((s_own, s_intr, action, reward, s2_own, s2_intr, done))

    def train_step(self):
        if len(self.memory) < self.config.batch_size:
            return None
        batch = random.sample(self.memory, self.config.batch_size)
        s_own, s_intr, acts, rws, s2_own, s2_intr, dns = zip(*batch)

        s_own = torch.stack(s_own).to(self.device)
        acts = torch.tensor(acts, dtype=torch.long).to(self.device)
        rws = torch.tensor(rws, dtype=torch.float).to(self.device)
        s2_own = torch.stack(s2_own).to(self.device)
        dns = torch.tensor(dns, dtype=torch.float).to(self.device)

        # Pad and stack the intruder tensors
        max_intruders = max(x.size(0) for x in s_intr)
        max_intruders = min(max_intruders, self.config.max_intruders)

        intr_batch = []
        for s_i in s_intr:
            padded = self.pad_intruder_states(s_i.unsqueeze(0), max_intruders).squeeze(0)
            intr_batch.append(padded)
        intr_batch = torch.stack(intr_batch).to(self.device)

        intr2_batch = []
        for s2_i in s2_intr:
            padded2 = self.pad_intruder_states(s2_i.unsqueeze(0), max_intruders).squeeze(0)
            intr2_batch.append(padded2)
        intr2_batch = torch.stack(intr2_batch).to(self.device)

        q_curr = self.policy_net(s_own, intr_batch)  # [B, #actions]
        q_curr_a = q_curr.gather(1, acts.unsqueeze(1))

        with torch.no_grad():
            next_actions = self.policy_net(s2_own, intr2_batch).argmax(1)
            q_next = self.target_net(s2_own, intr2_batch)
            q_next_a = q_next.gather(1, next_actions.unsqueeze(1)).squeeze(1)
            target_q = rws + (1 - dns) * self.config.gamma * q_next_a

        loss_fn = nn.SmoothL1Loss()
        loss = loss_fn(q_curr_a.squeeze(), target_q)

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)
        self.optimizer.step()

        self.epsilon = max(self.config.epsilon_end, self.epsilon * self.config.epsilon_decay)
        return loss.item()

    def soft_update_target_network(self, tau=0.001):
        for tp, pp in zip(self.target_net.parameters(), self.policy_net.parameters()):
            tp.data.copy_(tau * pp.data + (1.0 - tau) * tp.data)


def train_cas(config: Config, num_episodes: int = 500):
    env = MultiAircraftEnv(config)
    agent = DQNAgent(config)

    current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    log_dir = os.path.join('runs', f'CAS_training_{current_time}')
    writer = SummaryWriter(log_dir)

    episode_rewards = []
    nmac_counts = []

    for episode in range(num_episodes):
        own_state, intr_state = env.reset()
        ep_reward = 0.0
        steps = 0
        nmac_count = 0
        episode_loss = 0.0
        num_training_steps = 0
        episode_data = []

        while True:
            # Log the current state
            current_info = {
                "own_state": {
                    "altitude": env.own_state['altitude'],
                    "v_speed_o": env.own_state['v_speed_o'],
                    "h_speed_o": env.own_state['h_speed_o'],
                    "action": action_number_map[env.own_state['prev_action']]
                },
                "intruders": [{
                    "id": x['id'],
                    "x_range": x['x_range'],
                    "y_range": x['y_range'],
                    "rel_h": x['rel_h'],
                    "v_speed": x['v_speed'],
                    "horizontal_speed": x['horizontal_speed'],
                    "angle": x['angle'],
                    "escaped": x['escaped']
                } for x in env.intruder_states]
            }
            episode_data.append(current_info)

            action = agent.select_action(own_state, intr_state)
            (next_own, next_intr), reward, done = env.step(action)
            agent.store_transition(own_state, intr_state, action, reward, next_own, next_intr, done)

            ep_reward += reward
            steps += 1

            # Train
            if len(agent.memory) >= config.batch_size:
                loss = agent.train_step()
                if loss is not None:
                    episode_loss += loss
                    num_training_steps += 1
                # Soft update
                agent.soft_update_target_network(tau=config.tau)

            if reward <= -100:
                nmac_count += 1

            own_state, intr_state = next_own, next_intr
            if done:
                record_episode(env, episode_data, f'episode_{episode}.txt')
                break

        episode_rewards.append(ep_reward)
        nmac_counts.append(nmac_count)

        writer.add_scalar('1.Rewards/Episode_Reward', ep_reward, episode)
        if len(episode_rewards) > 0:
            avg_r = sum(episode_rewards[-100:]) / min(len(episode_rewards), 100)
            writer.add_scalar('1.Rewards/Average_Reward', avg_r, episode)
        writer.add_scalar('2.Safety/NMAC_Count', nmac_count, episode)
        if len(nmac_counts) > 0:
            avg_n = sum(nmac_counts[-100:]) / min(len(nmac_counts), 100)
            writer.add_scalar('2.Safety/Average_NMAC', avg_n, episode)
        writer.add_scalar('3.Training/Steps', steps, episode)
        writer.add_scalar('3.Training/Epsilon', agent.epsilon, episode)
        if num_training_steps > 0:
            writer.add_scalar('3.Training/Average_Loss', episode_loss / num_training_steps, episode)

        stats = env.get_statistics()
        writer.add_scalar('6.Intruders/Active_Count', stats['active_intruders'], episode)
        writer.add_scalar('6.Intruders/Escaped_Count', stats['escaped_intruders'], episode)
        writer.add_scalar('6.Intruders/Total_Intruders', stats['total_intruders'], episode)
        writer.add_scalar('7.Own_State/Action_Changes', stats['action_changes'], episode)

        print(f"\nEpisode {episode + 1}/{num_episodes}")
        print(f"Steps: {steps}, Reward: {ep_reward:.2f}, NMAC: {nmac_count}, Eps: {agent.epsilon:.3f}")
        if len(episode_rewards) > 0:
            avg_r = sum(episode_rewards[-100:]) / min(len(episode_rewards), 100)
            print(f"AvgReward(100): {avg_r:.2f}")
        if len(nmac_counts) > 0:
            avg_n = sum(nmac_counts[-100:]) / min(len(nmac_counts), 100)
            print(f"AvgNMAC(100): {avg_n:.3f}")
        if num_training_steps > 0:
            print(f"AvgLoss: {episode_loss / num_training_steps:.6f}")
        print(f"Active Intruders: {stats['active_intruders']}, Escaped: {stats['escaped_intruders']}, "
              f"Total: {stats['total_intruders']}")
        print(f"Action Changes: {stats['action_changes']}")
        print("-" * 50)

    writer.close()
    return agent, episode_rewards, nmac_counts


def evaluate_cas(agent: DQNAgent, config: Config, num_episodes: int = 100):
    env = MultiAircraftEnv(config)

    current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    log_dir = os.path.join('runs', f'CAS_evaluation_{current_time}')
    writer = SummaryWriter(log_dir)

    total_rewards = []
    nmac_counts = []
    step_counts = []

    old_eps = agent.epsilon
    agent.epsilon = 0.0

    for ep in range(num_episodes):
        own_state, intr_state = env.reset()
        ep_reward = 0.0
        steps = 0
        nmac_count = 0

        while True:
            action = agent.select_action(own_state, intr_state)
            (next_own, next_intr), reward, done = env.step(action)
            ep_reward += reward
            steps += 1
            if reward <= -100:
                nmac_count += 1
            own_state, intr_state = next_own, next_intr
            if done:
                break

        writer.add_scalar('Evaluation/Episode_Reward', ep_reward, ep)
        writer.add_scalar('Evaluation/NMAC_Count', nmac_count, ep)
        writer.add_scalar('Evaluation/Steps', steps, ep)

        total_rewards.append(ep_reward)
        nmac_counts.append(nmac_count)
        step_counts.append(steps)

        print(f"\nEvaluation Episode {ep + 1}")
        print(f"Steps: {steps}, Reward: {ep_reward:.2f}, NMAC: {nmac_count}")

    avg_r = sum(total_rewards) / len(total_rewards)
    avg_n = sum(nmac_counts) / len(nmac_counts)
    avg_stp = sum(step_counts) / len(step_counts)

    writer.add_hparams(
        {'eval_episodes': num_episodes},
        {
            'hparam/avg_reward': avg_r,
            'hparam/avg_nmac': avg_n,
            'hparam/avg_steps': avg_stp
        }
    )
    writer.close()
    agent.epsilon = old_eps

    print("\nEvaluation Results:")
    print(f"Avg Reward: {avg_r:.2f}")
    print(f"Avg NMAC: {avg_n:.3f}")
    print(f"Avg Steps: {avg_stp:.1f}")
    return avg_r, avg_n, avg_stp


if __name__ == "__main__":
    config = Config()
    print("Starting training...")
    agent, ep_rewards, nmacs = train_cas(config, num_episodes=500)

    save_dir = 'models'
    os.makedirs(save_dir, exist_ok=True)
    current_time = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    model_path = os.path.join(save_dir, f'cas_model_{current_time}.pth')
    torch.save({
        'policy_net_state_dict': agent.policy_net.state_dict(),
        'target_net_state_dict': agent.target_net.state_dict(),
        'optimizer_state_dict': agent.optimizer.state_dict(),
        'config': config,
    }, model_path)

    print("\nStarting evaluation...")
    evaluate_cas(agent, config, num_episodes=50)
```
**Simplified own-aircraft dynamics**
The own aircraft now carries a horizontal speed `h_speed_o`. Whenever a climb or descent action is executed, a coupling term (`speed_coupling_factor`) is applied to the horizontal speed: as the vertical speed rises, the horizontal speed decays accordingly; as the vertical speed falls, the decay shrinks and the speed may even increase slightly. A simple drag term can also be enabled through `drag_factor`. In a more realistic setup, a full thrust/lift/gravity/drag model could be numerically integrated instead; a minimal sketch of such a point-mass model follows.
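As one possible refinement (not part of the original code), a point-mass energy balance could drive the horizontal speed instead of the linear coupling. The sketch below assumes made-up `thrust_excess` and `drag_coeff` values and works entirely in ft and ft/s units:

```python
# Illustrative only: a point-mass longitudinal model that could replace the linear
# speed coupling. The thrust/drag constants are placeholder values, not from the code above.
import math

G_FPS2 = 32.174  # gravitational acceleration in ft/s^2, for consistent ft/s units


def update_own_speeds(h_speed, v_speed, accel_v, dt,
                      thrust_excess=2.0,    # assumed excess specific thrust, ft/s^2
                      drag_coeff=1.0e-5):   # assumed quadratic drag coefficient, 1/ft
    """One Euler step of a simplified point-mass energy model (units: ft, s)."""
    # Vertical speed still follows the commanded acceleration, as in step().
    v_next = v_speed + accel_v * dt
    # Flight-path angle from the current speed components.
    speed = math.hypot(h_speed, v_next)
    sin_gamma = v_next / max(speed, 1e-6)
    # Along-track acceleration: excess thrust minus drag minus the gravity component.
    accel_along = thrust_excess - drag_coeff * speed ** 2 - G_FPS2 * sin_gamma
    speed_next = max(speed + accel_along * dt, 1.0)
    # Recover the horizontal component, keeping v_next as the vertical speed.
    h_next = math.sqrt(max(speed_next ** 2 - v_next ** 2, 1.0))
    return h_next, v_next


# Example: one 1-second step while climbing at ~2000 fpm
print(update_own_speeds(h_speed=200 * 1.6878098571, v_speed=2000 / 60, accel_v=0.0, dt=1.0))
```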
**Multi-episode continuous traffic**
The world extends `world_size_x` = `world_size_y` = 8 NMI. At every step a new intruder is generated with a fixed probability (`spawn_prob`); its position is initialized at random near the boundary, heading roughly toward the own aircraft or the center of the area. Intruders that move away or leave the area are marked `escaped` and removed from the environment.
**Remaining logic**
The rest of the pipeline is unchanged: training still uses Double DQN with soft target-network updates (`soft_update_target_network`). You can tune `speed_coupling_factor`, `spawn_prob`, `max_own_hspeed`, `max_steps_per_episode`, and so on, to reach the level of complexity you need.

Can I store the absolute coordinates of the own aircraft and the intruders in the state space, but train the model using only relative positions and relative velocities?
Yes, and that is a common practice. In the actual implementation you can keep maintaining absolute positions inside the environment (convenient for visualization, trajectory logging, boundary checks, and so on), while feeding only the variables that are directly relevant to conflict detection and avoidance, i.e. relative position, relative velocity, and relative altitude, into the network for training. The benefits include:

**Translation invariance of the network input**
Using only relative positions and relative velocities as the agent's input means the model does not have to learn the meaning of any particular "absolute coordinate" on the map, so the network is not pushed to memorize irrelevant absolute positions.
In collision-avoidance scenarios, what actually determines the collision risk is the relative geometry between the aircraft (range, relative bearing, relative speed, and so on), not their specific locations in the world frame.

**Generalization across scenarios**
If the decision can be made from the relative state of two (or more) aircraft alone, the same model can be used for inference in a different region (for example with a different coordinate origin) without any re-adaptation.

**Absolute coordinates kept for logging and plotting**
Inside the environment you can still maintain absolute coordinates, which makes it easy to visualize the true positions and speeds in the simulator, or to run any other statistics and analyses that need world coordinates. For training, those coordinates are simply converted into relative state before being handed to the network.
A typical workflow:

- **Maintain absolute coordinates in the environment**: keep `(own_x, own_y, own_altitude, own_vx, own_vy, ...)` for the own aircraft, and likewise the absolute position and velocity `(intr_x, intr_y, intr_alt, intr_vx, intr_vy, ...)` for every intruder.
- **Compute the relative state**: before `env.step()` returns the observation to the agent, convert the absolute coordinates and velocities into relative values (a minimal sketch follows this list).
- **Keep the network inputs and outputs unchanged**: pack (relative position, relative velocity) and the other relative features into tensors and feed them to the model; the model never needs to see the absolute coordinates.

**In short, "store absolute in the environment, use relative in the model"** is a common design pattern: it keeps the environment logic simple while giving the training state representation the invariance it needs and reducing the learning difficulty.
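As an illustration, the conversion step might look like the sketch below; the dictionary keys (`x`, `y`, `alt`, `vx`, ...) are hypothetical and would need to be mapped to whatever the environment actually stores:

```python
# Minimal sketch (hypothetical field names, not the ones used by MultiAircraftEnv):
# the environment stores absolute kinematics, and only relative quantities are exposed
# to the agent.
import math


def relative_observation(own: dict, intr: dict) -> list:
    """Convert absolute own/intruder states into relative features for the network."""
    rel_x = intr['x'] - own['x']        # relative horizontal position, ft
    rel_y = intr['y'] - own['y']
    rel_h = intr['alt'] - own['alt']    # relative altitude, ft
    rel_vx = intr['vx'] - own['vx']     # relative horizontal velocity, ft/s
    rel_vy = intr['vy'] - own['vy']
    rel_vz = intr['vz'] - own['vz']     # relative vertical speed, ft/s
    rel_range = math.hypot(rel_x, rel_y)
    rel_bearing = math.atan2(rel_y, rel_x)  # bearing of the intruder from the own aircraft
    return [rel_x, rel_y, rel_h, rel_vx, rel_vy, rel_vz, rel_range, rel_bearing]


own = {'x': 0.0, 'y': 0.0, 'alt': 15000.0, 'vx': 337.6, 'vy': 0.0, 'vz': 0.0}
intr = {'x': 12000.0, 'y': -3000.0, 'alt': 15800.0, 'vx': -400.0, 'vy': 60.0, 'vz': -10.0}
print(relative_observation(own, intr))
```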
What exactly are the input variables used to train my neural-network model? Are they reasonable, and should anything else be added?
In the current (example) code, the network input consists of two parts: **the own-aircraft state and the intruder states**. Following the example given earlier, it looks roughly like this:

Own-aircraft state (packed into a single 1-D vector):

- `v_speed_o`: the own aircraft's vertical speed
- `prev_action`: the action index of the previous step (with a discrete action space it can be fed in directly as an integer, or passed through an embedding first; see the sketch after this list)
- `h_speed_o`: the own aircraft's horizontal speed

Intruder states (one vector per intruder, stacked into an [N, D] matrix, where N is the number of intruders and D the per-intruder feature dimension):

- `rel_h`: relative altitude (intruder altitude minus own altitude)
- `v_speed`: the intruder's vertical speed
- `x_range`, `y_range`: the intruder's horizontal position relative to the own aircraft (relative positions only; no absolute coordinates go into the network)
- `horizontal_speed`: the magnitude of the intruder's horizontal speed
- `angle`: the direction of the intruder's horizontal velocity (relative to some reference axis, e.g. in radians from the positive X axis)

Before these inputs are fed to the network, they are normally normalized or standardized, for example by dividing by a maximum physical value or scaling into [-1, 1]. The code above does exactly that through `config.normalization[...]`, dividing each element so that the values are easier to train on.
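If `prev_action` is embedded rather than fed as a raw integer, the own-state encoder could look like the hedged sketch below (an illustrative variant, not the `CASNetwork` defined above):

```python
# Illustrative variant of the own-state encoder: embed the discrete prev_action
# instead of feeding the raw action index as a float. Not part of the original code.
import torch
import torch.nn as nn


class OwnStateEncoder(nn.Module):
    def __init__(self, num_actions: int = 17, hidden_dim: int = 128, action_dim: int = 8):
        super().__init__()
        self.action_embedding = nn.Embedding(num_actions, action_dim)
        # 2 continuous features (v_speed_o, h_speed_o) plus the action embedding
        self.mlp = nn.Sequential(nn.Linear(2 + action_dim, hidden_dim), nn.ReLU())

    def forward(self, v_speed_o: torch.Tensor, h_speed_o: torch.Tensor,
                prev_action: torch.Tensor) -> torch.Tensor:
        act = self.action_embedding(prev_action)             # [B, action_dim]
        feats = torch.stack([v_speed_o, h_speed_o], dim=1)   # [B, 2]
        return self.mlp(torch.cat([feats, act], dim=1))      # [B, hidden_dim]


encoder = OwnStateEncoder()
out = encoder(torch.tensor([0.2]), torch.tensor([0.56]), torch.tensor([11]))
print(out.shape)  # torch.Size([1, 128])
```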
Below are some features that can be useful but are not strictly required in many implementations, for your reference (a combined sketch follows this list):

- **Relative velocity components**: the intruder state currently provides `horizontal_speed` and `angle`, and the own aircraft provides `h_speed_o`. You could instead feed the relative horizontal velocity components (vx, vy) to the network, or precompute them as (`vx_int - vx_own`, `vy_int - vy_own`).
- **Relative heading difference**: provide `(angle_intruder - angle_own)` or a similarly simplified difference rather than two separate absolute angles.
- **Time To Collision (TTC) or Closest Point of Approach (CPA)**: the horizontal tau (`tau_h`) can be used directly as one of the network input features, so the network explicitly knows how many seconds remain before a potential collision; you then need to decide how to handle `tau_h = inf` (clip it, or replace it with a fixed large value).
- **History information, or velocities and accelerations over time**
- **Other external information**
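As a sketch of how such extra per-intruder features could be assembled (reusing the dict keys from the environment above, with an assumed `TAU_CAP` constant standing in for the infinite-tau case):

```python
# Sketch: optional extra features per intruder. Uses the same dict keys as the
# environment above; TAU_CAP is an assumed placeholder constant.
import math

TAU_CAP = 300.0  # assumed cap used in place of tau_h = inf, in seconds


def extra_features(intruder: dict, own_hspeed: float, own_angle: float = 0.0) -> list:
    """Relative velocity components, heading difference, and a clipped horizontal tau."""
    vx_int = intruder['horizontal_speed'] * math.cos(intruder['angle'])
    vy_int = intruder['horizontal_speed'] * math.sin(intruder['angle'])
    # Relative horizontal velocity (own aircraft assumed to fly along +X, as in the env).
    rel_vx = vx_int - own_hspeed
    rel_vy = vy_int
    # Relative heading difference, wrapped into [-pi, pi).
    d_angle = (intruder['angle'] - own_angle + math.pi) % (2 * math.pi) - math.pi
    # Horizontal tau (same formula as compute_tau_horizontal), clipped to a finite value.
    dist_h = math.hypot(intruder['x_range'], intruder['y_range'])
    closing = -(intruder['x_range'] * rel_vx + intruder['y_range'] * rel_vy) / max(dist_h, 1e-5)
    tau_h = dist_h / closing if closing > 1e-5 else TAU_CAP
    return [rel_vx, rel_vy, d_angle, min(tau_h, TAU_CAP)]


it = {'x_range': 10000.0, 'y_range': 2000.0, 'horizontal_speed': 500.0, 'angle': math.pi}
print(extra_features(it, own_hspeed=337.6))
```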
Overall, the current inputs (`[v_speed_o, prev_action, h_speed_o]` plus, for each intruder, `[rel_h, v_speed, x_range, y_range, horizontal_speed, angle]`) already cover most of the information needed for conflict detection and avoidance.