夜光超分损失函数
Tạo vào: 3 tháng 1, 2025
# main.py
import argparse
import os
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from pannet import PanNet
from utils.landsat8_dataset import Landsat8Dataset
from utils.misc import get_patch_configs
from utils.pansharpening_evaluator import evaluate
import numpy as np
import random
import json
import matplotlib.pyplot as plt
import argparse
from utils import misc
from utils.loss import MSELoss
def plot_losses(train_losses, val_losses, out_dir):
    """Plot the training and validation loss curves and save them as loss.png.

    Args:
        train_losses: per-epoch average training losses.
        val_losses: per-epoch average validation losses.
        out_dir: directory the figure is written into.
    """
    plt.figure()
    curves = ((train_losses, 'Training Loss'), (val_losses, 'Validation Loss'))
    for values, label in curves:
        plt.plot(values, label=label)
    plt.xlabel('Epoch')
    plt.ylabel('MSE Loss')
    plt.legend()
    plt.title('Training and Validation Loss')
    figure_path = os.path.join(out_dir, 'loss.png')
    plt.savefig(figure_path)
    plt.close()
def main():
    """Train PanNet on Landsat-8 patches and save the best/final checkpoints.

    Fixes vs. the original: `textos.makedirs` -> `os.makedirs` (paste
    artifact), `device` was assigned twice (kept the per-GPU form), and the
    entry-point guard used `name`/'main' instead of `__name__`/'__main__'.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--alpha', type=float, default=0.001)
    parser.add_argument('--beta1', type=float, default=0.9)
    parser.add_argument('--beta2', type=float, default=0.999)
    parser.add_argument('--eps', type=float, default=1e-8)
    parser.add_argument('--eta', type=float, default=1.0)
    parser.add_argument('--weight_decay_rate', type=float, default=1e-4)
    parser.add_argument('--amsgrad', action='store_true', help='Use AMSGrad variant of AdamW')
    parser.add_argument('--data_dir', default='/content/drive/MyDrive/d2l-zh/PanNet/data/train/')
    parser.add_argument('--out_model_dir', default='/content/drive/MyDrive/d2l-zh/PanNet/data/out_model/')
    parser.add_argument('--no_plot', action='store_false', dest='plot', help='Disable plotting of loss')
    parser.add_argument('--random_seed', default=1, type=int,
                        help='Random seed for choosing cell ids and capture ids (default: 1)')
    args = parser.parse_args()

    os.makedirs(args.out_model_dir, exist_ok=True)

    # Device configuration (single definition; selects the requested GPU index).
    device = torch.device(f'cuda:{args.gpu}' if args.gpu >= 0 and torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Prepare datasets
    print('Preparing training data...')
    training_patch_configs, validation_patch_configs, _ = misc.get_patch_configs(
        args.data_dir,
        training_data_rate=0.8,
        validation_data_rate=0.2,
        test_data_rate=0.0,
        patch_width=16,
        patch_height=16,
        num_patches=1000,
        random_seed=args.random_seed,
    )
    training_dataset = Landsat8Dataset(training_patch_configs)
    validation_dataset = Landsat8Dataset(validation_patch_configs)
    print('The number of training patches:', len(training_dataset))
    print('The number of validation patches:', len(validation_dataset))

    # Data loaders
    train_loader = DataLoader(training_dataset, batch_size=args.batch_size,
                              shuffle=True, num_workers=2)
    val_loader = DataLoader(validation_dataset, batch_size=1,
                            shuffle=False, num_workers=2)

    # Initialize model
    model = PanNet(out_channels=1).to(device)

    # Define loss and optimizer
    criterion = torch.nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=args.alpha,
                           betas=(args.beta1, args.beta2), eps=args.eps,
                           weight_decay=args.weight_decay_rate,
                           amsgrad=args.amsgrad)

    best_val_loss = float('inf')
    train_losses = []
    val_losses = []

    for epoch in range(1, args.epochs + 1):
        model.train()
        running_loss = 0.0
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if (batch_idx + 1) % 100 == 0:
                print(f'Epoch [{epoch}/{args.epochs}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.6f}')
        avg_train_loss = running_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        print(f'Epoch [{epoch}/{args.epochs}], Average Training Loss: {avg_train_loss:.6f}')

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = inputs.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()
        avg_val_loss = val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        print(f'Epoch [{epoch}/{args.epochs}], Validation Loss: {avg_val_loss:.6f}')

        # Save best model (by validation loss)
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model_path = os.path.join(args.out_model_dir, 'best_model.pth')
            torch.save(model.state_dict(), best_model_path)
            print(f'Best model saved at epoch {epoch}')

    # Save final model
    final_model_path = os.path.join(args.out_model_dir, f'model_epoch-{args.epochs}.pth')
    torch.save(model.state_dict(), final_model_path)
    print(f'Final model saved at epoch {args.epochs}')

    # Plot losses
    if args.plot:
        plot_losses(train_losses, val_losses, args.out_model_dir)
        print(f'Loss plot saved in {args.out_model_dir}/loss.png')

    print('Training completed.')


if __name__ == '__main__':
    main()
# losses/mse_loss.py
import torch
import torch.nn as nn
class MSELoss(nn.Module):
    """MSE training loss that also reports auxiliary metrics.

    Wraps a predictor network: ``forward`` runs the predictor on the input
    and compares its output against the ground truth.

    Fixes vs. the original: the constructor was named ``init`` and called
    ``self.init()``, so ``nn.Module.__init__`` never ran; the ``forward``
    definition also had a stray ``text`` prefix fused onto ``def``.
    """

    def __init__(self, predictor):
        super(MSELoss, self).__init__()
        self.predictor = predictor
        self.mse_loss = nn.MSELoss()
        self.mae_loss = nn.L1Loss()

    def forward(self, x, t):
        """Return ``(mse, mse_spectra_mapping_and_gt, abs_error)``.

        Args:
            x: predictor input; channels from index 3 onward are compared
               directly against ``t`` (assumes they match ``t``'s channel
               count — TODO confirm against the dataset layout).
            t: ground-truth target tensor.
        """
        y = self.predictor(x)
        loss = self.mse_loss(y, t)
        # Reference metric: MSE between the input's trailing channels and the
        # ground truth (not the optimized objective). Log as needed.
        mse_spectra_mapping_and_gt = self.mse_loss(x[:, 3:, :, :], t)
        abs_error = self.mae_loss(y, t)
        return loss, mse_spectra_mapping_and_gt, abs_error
# evaluators/pansharpening_evaluator.py
import torch
import numpy as np
def evaluate(model, data_loader, device):
    """Return the mean per-batch MSE of ``model`` over ``data_loader``.

    The model is switched to eval mode and gradients are disabled; each
    batch is moved to ``device`` before the forward pass.
    """
    model.eval()
    per_batch_errors = []
    with torch.no_grad():
        for batch_in, batch_gt in data_loader:
            batch_in = batch_in.to(device)
            batch_gt = batch_gt.to(device)
            prediction = model(batch_in)
            squared_diff = (prediction - batch_gt) ** 2
            per_batch_errors.append(squared_diff.mean().item())
    return np.mean(per_batch_errors)
这个模型的目标是通过多光谱遥感影像融合低分辨率夜光遥感影像,实现夜光超分辨率重建,目标影像 targets 是高分辨率夜光数据。关于损失函数应该怎么选择呢?请你给出相应建议,并给我完整的修改后代码,并且中文解释。
下面的回答会先给出一些中文解释,说明为什么在夜光超分辨率(Nighttime Light Super-Resolution)或融合多光谱数据时,仅仅使用单纯的 MSELoss(均方误差)可能存在的问题,以及在实际应用中可以考虑哪些损失函数。随后会提供一份示例性的完整修改后代码(包含 main.py
、自定义损失函数文件,以及评估脚本),其中演示了如何将损失函数改为更加适合夜光超分辨率重建任务的组合损失(例如 L1 + SSIM)。
MSE 可能会导致过度平滑 (Over-smoothing)
对于遥感图像尤其是夜光数据而言,目标在图像上的分布比较稀疏,MSELoss 往往会造成生成图像的模糊化,导致小目标或亮点被平滑掉。
L1 Loss(MAE)对异常值更鲁棒
相比于 MSELoss(对差值平方),L1Loss(对差值绝对值)对图像中的异常值更具鲁棒性,不会过度惩罚局部大差异,从而在某些遥感场景下,能产生视觉效果更优的结果。
结构相似性 (SSIM) 能更好衡量图像质量
SSIM 比简单的像素差更能体现图像结构的一致性,使用 SSIM 相关的损失可以让模型在注重像素值一致的同时,也保留夜光影像的细节、纹理等结构信息。
可考虑 Perceptual Loss / GAN Loss
如果对重建的夜光图像主观视觉要求更高,可以进一步引入感知损失(Perceptual Loss)或对抗损失(GAN Loss)等。不过这些相对复杂,这里以 L1 + SSIM 为例演示。
可以用一个组合损失来同时兼顾像素级误差和结构相似度。比如,令

$$L_{\text{total}} = \alpha \cdot L_1 + (1 - \alpha) \cdot (1 - \mathrm{SSIM})$$

其中 $\alpha$ 是一个系数(0.5~0.9 都可实验),以调节重建的平滑度和保留细节的平衡。
如果不想引入 SSIM,也可纯粹使用 L1Loss 来替换原先的 MSELoss,获得更稳定的梯度、减少过度平滑。
如果想引入对抗训练或感知损失,则需要更复杂的网络结构与训练流程,这里暂不赘述。
下面给出的示例包含三个文件的核心代码,分别是:

1. `main.py`:将原先的 `torch.nn.MSELoss()` 改为自定义的 `NightLightLoss`(L1 + SSIM)。
2. `losses/night_light_loss.py`:定义 `NightLightLoss` 类,内部使用 L1 与 SSIM 的组合损失,借助 `pytorch_msssim` 来计算 SSIM(需要先安装 `pip install pytorch_msssim`)。
3. `evaluators/pansharpening_evaluator.py`:评估脚本,保持原有的 MSE 评估流程。

请根据你自己的文件结构,将对应部分放到相应的目录,例如 `losses/`、`utils/`、`evaluators/` 等。若你不想使用 SSIM,可只在 `NightLightLoss` 里保留 L1 部分即可。
main.py
# main.py
import argparse
import os
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import random
import json
import matplotlib.pyplot as plt

# 请根据你的实际目录进行导入
from pannet import PanNet
from utils.landsat8_dataset import Landsat8Dataset
from utils.misc import get_patch_configs
from evaluators.pansharpening_evaluator import evaluate

# 使用新的自定义损失函数
from losses.night_light_loss import NightLightLoss


def plot_losses(train_losses, val_losses, out_dir):
    """Plot training/validation loss curves and save them as loss.png."""
    plt.figure()
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training and Validation Loss')
    plt.savefig(os.path.join(out_dir, 'loss.png'))
    plt.close()


def main():
    """Train PanNet with the combined NightLightLoss (L1 + SSIM)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--alpha', type=float, default=0.001)
    parser.add_argument('--beta1', type=float, default=0.9)
    parser.add_argument('--beta2', type=float, default=0.999)
    parser.add_argument('--eps', type=float, default=1e-8)
    parser.add_argument('--eta', type=float, default=1.0)
    parser.add_argument('--weight_decay_rate', type=float, default=1e-4)
    parser.add_argument('--amsgrad', action='store_true', help='Use AMSGrad variant of AdamW')
    parser.add_argument('--data_dir', default='/content/drive/MyDrive/d2l-zh/PanNet/data/train/')
    parser.add_argument('--out_model_dir', default='/content/drive/MyDrive/d2l-zh/PanNet/data/out_model/')
    parser.add_argument('--no_plot', action='store_false', dest='plot', help='Disable plotting of loss')
    parser.add_argument('--random_seed', default=1, type=int,
                        help='Random seed for choosing patches')
    args = parser.parse_args()

    os.makedirs(args.out_model_dir, exist_ok=True)

    # Set device
    device = torch.device('cuda' if (args.gpu >= 0 and torch.cuda.is_available()) else 'cpu')

    # Prepare datasets
    print('Preparing training data...')
    training_patch_configs, validation_patch_configs, _ = get_patch_configs(
        args.data_dir,
        training_data_rate=0.8,
        validation_data_rate=0.2,
        test_data_rate=0.0,
        patch_width=16,
        patch_height=16,
        num_patches=1000,
        random_seed=args.random_seed,
    )
    training_dataset = Landsat8Dataset(training_patch_configs)
    validation_dataset = Landsat8Dataset(validation_patch_configs)
    print('The number of training patches:', len(training_dataset))
    print('The number of validation patches:', len(validation_dataset))

    # Data loaders
    train_loader = DataLoader(training_dataset, batch_size=args.batch_size,
                              shuffle=True, num_workers=2)
    val_loader = DataLoader(validation_dataset, batch_size=1,
                            shuffle=False, num_workers=2)
    print(f'Using device: {device}')

    # Initialize model
    model = PanNet(out_channels=1).to(device)

    # Define loss and optimizer
    # 这里将原先 torch.nn.MSELoss() 替换为新的 NightLightLoss
    # alpha 值可根据需要在 0.5~0.9 之间调整;默认为 0.8
    criterion = NightLightLoss(alpha=0.8).to(device)
    optimizer = optim.Adam(
        model.parameters(),
        lr=args.alpha,
        betas=(args.beta1, args.beta2),
        eps=args.eps,
        weight_decay=args.weight_decay_rate,
        amsgrad=args.amsgrad,
    )

    best_val_loss = float('inf')
    train_losses = []
    val_losses = []

    for epoch in range(1, args.epochs + 1):
        model.train()
        running_loss = 0.0
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            optimizer.zero_grad()
            outputs = model(inputs)
            # NightLightLoss 直接把模型输出和真实值作为输入
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if (batch_idx + 1) % 100 == 0:
                print(f'Epoch [{epoch}/{args.epochs}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.6f}')
        avg_train_loss = running_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        print(f'Epoch [{epoch}/{args.epochs}], Average Training Loss: {avg_train_loss:.6f}')

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = inputs.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()
        avg_val_loss = val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        print(f'Epoch [{epoch}/{args.epochs}], Validation Loss: {avg_val_loss:.6f}')

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model_path = os.path.join(args.out_model_dir, 'best_model.pth')
            torch.save(model.state_dict(), best_model_path)
            print(f'Best model saved at epoch {epoch}')

    # Save final model
    final_model_path = os.path.join(args.out_model_dir, f'model_epoch-{args.epochs}.pth')
    torch.save(model.state_dict(), final_model_path)
    print(f'Final model saved at epoch {args.epochs}')

    # Plot losses
    if args.plot:
        plot_losses(train_losses, val_losses, args.out_model_dir)
        print(f'Loss plot saved in {args.out_model_dir}/loss.png')

    print('Training completed.')


if __name__ == '__main__':
    main()
losses/night_light_loss.py
注意:如果你不想用 SSIM,只想用 L1,可将 SSIM 部分删掉即可。
若使用 SSIM,需要先安装:bashpip install pytorch_msssim
python# losses/night_light_loss.py import torch import torch.nn as nn from pytorch_msssim import ssim class NightLightLoss(nn.Module): """ 一个简单的 L1 + SSIM 组合损失: total_loss = alpha * L1 + (1 - alpha) * (1 - ssim) alpha 可以根据需要进行调参(0.5~0.9)。 """ def __init__(self, alpha=0.8): super(NightLightLoss, self).__init__() self.alpha = alpha self.l1_loss = nn.L1Loss() def forward(self, pred, target): # 计算 L1 Loss l1_val = self.l1_loss(pred, target) # 计算 SSIM (取值越高,图像越相似,所以这里是 1 - SSIM) # data_range 可以根据图像像素值范围调整,一般夜光数据可先标准化到 [0,1] ssim_val = ssim(pred, target, data_range=1.0) total_loss = self.alpha * l1_val + (1.0 - self.alpha) * (1.0 - ssim_val) return total_loss
如上,如果你的数据像素值范围不是 [0,1],请务必设置好 data_range=
参数(例如最大像素值为 255 则设 data_range=255
)。
evaluators/pansharpening_evaluator.py
如果目前只想维持原有的评估流程(用 MSE 作参考指标),那就不需要大改。也可以根据实际需要,引入其他指标(PSNR、SSIM 等)。下面代码保持和你给出的示例基本一致,仅示例性贴出:
python# evaluators/pansharpening_evaluator.py import torch import numpy as np def evaluate(model, data_loader, device): model.eval() mse_list = [] with torch.no_grad(): for inputs, targets in data_loader: inputs = inputs.to(device) targets = targets.to(device) outputs = model(inputs) mse = torch.mean((outputs - targets) ** 2).item() mse_list.append(mse) avg_mse = np.mean(mse_list) return avg_mse
损失函数选择建议与使用方法:

- 在 `NightLightLoss` 中设置合适的 `alpha`;使用 SSIM 前需先安装 `pytorch_msssim`;若只想用 L1,可设 `alpha=1.0` 且不计算 SSIM。

重要注意点:

- 若数据像素值范围不是 [0, 1],在计算 SSIM 时要相应调整 `data_range`。

这样,你就可以更好地在夜光图像超分辨率重建中兼顾像素级一致性与夜光图像的细节结构,从而取得更好的定量指标与可视化效果。祝研究/项目顺利!