夜光超分损失函数

बनाया गया: 3 जनवरी 2025

का उपयोग करके उत्तर दिया गया O1 द्वारा Chat01

प्रश्न

“# main.py
import argparse
import os
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from pannet import PanNet
from utils.landsat8_dataset import Landsat8Dataset
from utils.misc import get_patch_configs
from utils.pansharpening_evaluator import evaluate
import numpy as np
import random
import json
import matplotlib.pyplot as plt
import argparse
from utils import misc
from utils.loss import MSELoss

def plot_losses(train_losses, val_losses, out_dir):
    """Plot per-epoch training/validation loss curves and save them.

    The figure is written to ``<out_dir>/loss.png`` and closed afterwards
    so repeated calls do not accumulate open matplotlib figures.
    """
    plt.figure()
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('MSE Loss')
    plt.legend()
    plt.savefig(os.path.join(out_dir, 'loss.png'))
    plt.close()

def main():
    """Train PanNet on Landsat-8 patches and save best/final checkpoints.

    Fixes over the pasted original:
    - ``device`` was computed twice (once as bare ``'cuda'``, once as
      ``'cuda:{gpu}'``); only the indexed form is kept.
    - the entry guard was ``if name == 'main':`` (missing dunders), so
      ``main()`` never ran; corrected to ``__name__ == '__main__'``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--alpha', type=float, default=0.001)
    parser.add_argument('--beta1', type=float, default=0.9)
    parser.add_argument('--beta2', type=float, default=0.999)
    parser.add_argument('--eps', type=float, default=1e-8)
    parser.add_argument('--eta', type=float, default=1.0)
    parser.add_argument('--weight_decay_rate', type=float, default=1e-4)
    parser.add_argument('--amsgrad', action='store_true', help='Use AMSGrad variant of AdamW')
    parser.add_argument('--data_dir', default='/content/drive/MyDrive/d2l-zh/PanNet/data/train/')
    parser.add_argument('--out_model_dir', default='/content/drive/MyDrive/d2l-zh/PanNet/data/out_model/')
    parser.add_argument('--no_plot', action='store_false', dest='plot', help='Disable plotting of loss')
    parser.add_argument('--random_seed', default=1, type=int,
                        help='Random seed for choosing cell ids and capture ids (default: 1)')
    args = parser.parse_args()

    os.makedirs(args.out_model_dir, exist_ok=True)

    # Single device setup: honour --gpu only when CUDA is actually available.
    device = torch.device(f'cuda:{args.gpu}' if args.gpu >= 0 and torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Prepare datasets.
    print('Preparing training data...')
    training_patch_configs, validation_patch_configs, _ = misc.get_patch_configs(
        args.data_dir,
        training_data_rate=0.8,
        validation_data_rate=0.2,
        test_data_rate=0.0,
        patch_width=16,
        patch_height=16,
        num_patches=1000,
        random_seed=args.random_seed,
    )
    training_dataset = Landsat8Dataset(training_patch_configs)
    validation_dataset = Landsat8Dataset(validation_patch_configs)
    print('The number of training patches:', len(training_dataset))
    print('The number of validation patches:', len(validation_dataset))

    # Data loaders (validation uses batch size 1 for per-patch loss).
    train_loader = DataLoader(training_dataset, batch_size=args.batch_size,
                              shuffle=True, num_workers=2)
    val_loader = DataLoader(validation_dataset, batch_size=1,
                            shuffle=False, num_workers=2)

    # Model, loss, optimizer.
    model = PanNet(out_channels=1).to(device)
    criterion = torch.nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=args.alpha,
                           betas=(args.beta1, args.beta2), eps=args.eps,
                           weight_decay=args.weight_decay_rate,
                           amsgrad=args.amsgrad)

    best_val_loss = float('inf')
    train_losses = []
    val_losses = []

    for epoch in range(1, args.epochs + 1):
        # ---- training ----
        model.train()
        running_loss = 0.0
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if (batch_idx + 1) % 100 == 0:
                print(f'Epoch [{epoch}/{args.epochs}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.6f}')
        avg_train_loss = running_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        print(f'Epoch [{epoch}/{args.epochs}], Average Training Loss: {avg_train_loss:.6f}')

        # ---- validation ----
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = inputs.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()
        avg_val_loss = val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        print(f'Epoch [{epoch}/{args.epochs}], Validation Loss: {avg_val_loss:.6f}')

        # Keep the checkpoint with the lowest validation loss.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model_path = os.path.join(args.out_model_dir, 'best_model.pth')
            torch.save(model.state_dict(), best_model_path)
            print(f'Best model saved at epoch {epoch}')

    # Always save the final-epoch weights as well.
    final_model_path = os.path.join(args.out_model_dir, f'model_epoch-{args.epochs}.pth')
    torch.save(model.state_dict(), final_model_path)
    print(f'Final model saved at epoch {args.epochs}')

    if args.plot:
        plot_losses(train_losses, val_losses, args.out_model_dir)
        print(f'Loss plot saved in {args.out_model_dir}/loss.png')

    print('Training completed.')


if __name__ == '__main__':
    main()
”“# losses/mse_loss.py
import torch
import torch.nn as nn

class MSELoss(nn.Module):
    """Wrap a predictor and report MSE plus auxiliary error metrics.

    The pasted original defined ``def init`` and called ``super(...).init()``
    instead of the ``__init__`` dunder, so the ``nn.Module`` machinery was
    never initialised and construction raised; both names are fixed here.
    """

    def __init__(self, predictor):
        super(MSELoss, self).__init__()
        self.predictor = predictor        # callable mapping input x -> prediction y
        self.mse_loss = nn.MSELoss()
        self.mae_loss = nn.L1Loss()

    def forward(self, x, t):
        """Return ``(mse(y, t), mse(x[:, 3:], t), mae(y, t))`` for input x, target t."""
        y = self.predictor(x)
        loss = self.mse_loss(y, t)
        # MSE between the trailing input channels (3+) and the target —
        # presumably a spectra-mapping baseline vs. ground truth; confirm with caller.
        mse_spectra_mapping_and_gt = self.mse_loss(x[:, 3:, :, :], t)
        abs_error = self.mae_loss(y, t)
        return loss, mse_spectra_mapping_and_gt, abs_error

”# evaluators/pansharpening_evaluator.py
import torch
import numpy as np

def evaluate(model, data_loader, device):
    """Return the mean of per-batch MSE values of ``model`` over ``data_loader``."""
    model.eval()
    batch_errors = []
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            preds = model(inputs)
            squared_error = (preds - targets) ** 2
            batch_errors.append(torch.mean(squared_error).item())
    return np.mean(batch_errors)
这个模型的目标是通过多光谱遥感影像融合低分辨率夜光 遥感影像,实现夜光超分辨率重建、目标影像targets是高分辨率夜光数据,关于损失函数应该怎么选择呢?请你给出相应建议,并给我完整的修改后代码,并且中文解释

विचार

Selecting a loss function

I’m thinking about whether L1 or SSIM loss might be better than MSE for nighttime light super-resolution. Combining L1 with MS-SSIM or Perceptual Loss could be promising.

Revamping loss criteria

I’m piecing together a new loss function for training, possibly L1 or a blend like L1 + MSE. The goal is to replace or adapt the current MSELoss in the code.

Choosing the right approach

I’m considering L1 loss, or a blend of MSE and SSIM, for super-resolution. This combination should enhance both robustness and stability.

Piecing together the solution

I’m crafting a MixedLoss by merging MSE, L1, and SSIM, while ensuring clarity, including SSIM library installation, and mentioning potential oversmoothing with MSE.

Adjusting the plan

Piecing together individual scripts for MSELoss in "losses/mse_loss.py," main.py, and "pansharpening_evaluator.py" to keep the code structure intact.

Pulling together

I’m crafting a code snippet with modifications: renaming references, integrating the new MixedLoss, and opting for a L1 and SSIM combination for nighttime light super-resolution.

Piecing together the solution

I’m integrating the torch metric for SSIM, ensuring minimal complexity with a clear, adaptable code snippet for nighttime light super-resolution.

Shifting criteria

I’m updating the loss criterion to NightLightLoss via "from losses.night_light_loss import NightLightLoss" and adjusting imports from utils.loss to reflect this change.

Thought about loss function selection,持续 20 秒

उत्तर

下面的回答会先给出一些中文解释,说明为什么在夜光超分辨率(Nighttime Light Super-Resolution)或融合多光谱数据时,仅仅使用单纯的 MSELoss(均方误差)可能存在的问题,以及在实际应用中可以考虑哪些损失函数。随后会提供一份示例性的完整修改后代码(包含 main.py、自定义损失函数文件,以及评估脚本),其中演示了如何将损失函数改为更加适合夜光超分辨率重建任务的组合损失(例如 L1 + SSIM)。


一、为什么要考虑更合适的损失函数?

  1. MSE 可能会导致过度平滑 (Over-smoothing)
    对于遥感图像尤其是夜光数据而言,目标在图像上的分布比较稀疏,MSELoss 往往会造成生成图像的模糊化,导致小目标或亮点被平滑掉。

  2. L1 Loss(MAE)对异常值更鲁棒
    相比于 MSELoss(对差值平方),L1Loss(对差值绝对值)对图像中的异常值更具鲁棒性,不会过度惩罚局部大差异,从而在某些遥感场景下,能产生视觉效果更优的结果。

  3. 结构相似性 (SSIM) 能更好衡量图像质量
    SSIM 比简单的像素差更能体现图像结构的一致性,使用 SSIM 相关的损失可以让模型在注重像素值一致的同时,也保留夜光影像的细节、纹理等结构信息。

  4. 可考虑 Perceptual Loss / GAN Loss
    如果对重建的夜光图像主观视觉要求更高,可以进一步引入感知损失(Perceptual Loss)或对抗损失(GAN Loss)等。不过这些相对复杂,这里以 L1 + SSIM 为例演示。


二、具体建议

  • 可以用一个组合损失来同时兼顾像素级误差和结构相似度。比如,令

    \[ \text{Loss}_{\text{total}} = \alpha \cdot \text{L1Loss}(y_{\text{pred}}, y_{\text{true}}) + (1 - \alpha) \cdot \left[ 1 - \text{SSIM}(y_{\text{pred}}, y_{\text{true}}) \right] \]

    其中 \(\alpha\) 是一个系数(0.5~0.9 都可实验),以调节重建的平滑度和保留细节的平衡。

  • 如果不想引入 SSIM,也可纯粹使用 L1Loss 来替换原先的 MSELoss,获得更稳定的梯度、减少过度平滑。

  • 如果想引入对抗训练或感知损失,则需要更复杂的网络结构与训练流程,这里暂不赘述。


三、示例:完整修改后代码

下面给出的示例包含三个文件的核心代码,分别是:

  1. main.py

    • 训练入口脚本,展示如何把原来的 torch.nn.MSELoss() 改为自定义的 NightLightLoss (L1 + SSIM)。
    • 其余训练逻辑基本不变。
  2. losses/night_light_loss.py

    • 定义了 NightLightLoss 类,内部使用 L1 与 SSIM 的组合损失。
    • 使用第三方库 pytorch_msssim 来计算 SSIM(需要先安装 pip install pytorch_msssim)。
  3. evaluators/pansharpening_evaluator.py

    • 评估逻辑,大体保持不变,依然用简单的 MSE 作为度量,但如果需要,你也可以修改评估指标去计算 PSNR, SSIM 等。

请根据你自己的文件结构,将对应部分放到相应的目录,例如 losses/、utils/、evaluators/ 等。若你不想使用 SSIM,可只在 NightLightLoss 里保留 L1 部分即可。


1. 修改后的 main.py

python
# main.py
import argparse
import os
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import random
import json
import matplotlib.pyplot as plt

# Adjust these imports to match your actual project layout.
from pannet import PanNet
from utils.landsat8_dataset import Landsat8Dataset
from utils.misc import get_patch_configs
from evaluators.pansharpening_evaluator import evaluate

# Custom combined loss (L1 + SSIM) for nighttime-light super-resolution.
from losses.night_light_loss import NightLightLoss


def plot_losses(train_losses, val_losses, out_dir):
    """Save training/validation loss curves to ``<out_dir>/loss.png``."""
    plt.figure()
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training and Validation Loss')
    plt.savefig(os.path.join(out_dir, 'loss.png'))
    plt.close()


def main():
    """Train PanNet with NightLightLoss and save best/final checkpoints."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--alpha', type=float, default=0.001)
    parser.add_argument('--beta1', type=float, default=0.9)
    parser.add_argument('--beta2', type=float, default=0.999)
    parser.add_argument('--eps', type=float, default=1e-8)
    parser.add_argument('--eta', type=float, default=1.0)
    parser.add_argument('--weight_decay_rate', type=float, default=1e-4)
    parser.add_argument('--amsgrad', action='store_true', help='Use AMSGrad variant of AdamW')
    parser.add_argument('--data_dir', default='/content/drive/MyDrive/d2l-zh/PanNet/data/train/')
    parser.add_argument('--out_model_dir', default='/content/drive/MyDrive/d2l-zh/PanNet/data/out_model/')
    parser.add_argument('--no_plot', action='store_false', dest='plot', help='Disable plotting of loss')
    parser.add_argument('--random_seed', default=1, type=int,
                        help='Random seed for choosing patches')
    args = parser.parse_args()

    os.makedirs(args.out_model_dir, exist_ok=True)

    # Set device: use the requested GPU only when CUDA is available.
    device = torch.device('cuda' if (args.gpu >= 0 and torch.cuda.is_available()) else 'cpu')

    # Prepare datasets.
    print('Preparing training data...')
    training_patch_configs, validation_patch_configs, _ = get_patch_configs(
        args.data_dir,
        training_data_rate=0.8,
        validation_data_rate=0.2,
        test_data_rate=0.0,
        patch_width=16,
        patch_height=16,
        num_patches=1000,
        random_seed=args.random_seed,
    )
    training_dataset = Landsat8Dataset(training_patch_configs)
    validation_dataset = Landsat8Dataset(validation_patch_configs)
    print('The number of training patches:', len(training_dataset))
    print('The number of validation patches:', len(validation_dataset))

    # Data loaders.
    train_loader = DataLoader(training_dataset, batch_size=args.batch_size,
                              shuffle=True, num_workers=2)
    val_loader = DataLoader(validation_dataset, batch_size=1,
                            shuffle=False, num_workers=2)

    print(f'Using device: {device}')

    # Initialize model.
    model = PanNet(out_channels=1).to(device)

    # Loss and optimizer.
    # The original torch.nn.MSELoss() is replaced with NightLightLoss;
    # alpha can be tuned between roughly 0.5 and 0.9 (default 0.8).
    criterion = NightLightLoss(alpha=0.8).to(device)
    optimizer = optim.Adam(
        model.parameters(),
        lr=args.alpha,
        betas=(args.beta1, args.beta2),
        eps=args.eps,
        weight_decay=args.weight_decay_rate,
        amsgrad=args.amsgrad,
    )

    best_val_loss = float('inf')
    train_losses = []
    val_losses = []

    for epoch in range(1, args.epochs + 1):
        # ---- training ----
        model.train()
        running_loss = 0.0
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            optimizer.zero_grad()
            outputs = model(inputs)
            # NightLightLoss takes model output and ground truth directly.
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if (batch_idx + 1) % 100 == 0:
                print(f'Epoch [{epoch}/{args.epochs}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.6f}')
        avg_train_loss = running_loss / len(train_loader)
        train_losses.append(avg_train_loss)
        print(f'Epoch [{epoch}/{args.epochs}], Average Training Loss: {avg_train_loss:.6f}')

        # ---- validation ----
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = inputs.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()
        avg_val_loss = val_loss / len(val_loader)
        val_losses.append(avg_val_loss)
        print(f'Epoch [{epoch}/{args.epochs}], Validation Loss: {avg_val_loss:.6f}')

        # Save the checkpoint with the lowest validation loss so far.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model_path = os.path.join(args.out_model_dir, 'best_model.pth')
            torch.save(model.state_dict(), best_model_path)
            print(f'Best model saved at epoch {epoch}')

    # Save final model.
    final_model_path = os.path.join(args.out_model_dir, f'model_epoch-{args.epochs}.pth')
    torch.save(model.state_dict(), final_model_path)
    print(f'Final model saved at epoch {args.epochs}')

    # Plot losses.
    if args.plot:
        plot_losses(train_losses, val_losses, args.out_model_dir)
        print(f'Loss plot saved in {args.out_model_dir}/loss.png')

    print('Training completed.')


if __name__ == '__main__':
    main()

2. 新增文件:losses/night_light_loss.py

注意:如果你不想用 SSIM,只想用 L1,可将 SSIM 部分删掉即可。
若使用 SSIM,需要先安装:

bash
pip install pytorch_msssim
python
# losses/night_light_loss.py import torch import torch.nn as nn from pytorch_msssim import ssim class NightLightLoss(nn.Module): """ 一个简单的 L1 + SSIM 组合损失: total_loss = alpha * L1 + (1 - alpha) * (1 - ssim) alpha 可以根据需要进行调参(0.5~0.9)。 """ def __init__(self, alpha=0.8): super(NightLightLoss, self).__init__() self.alpha = alpha self.l1_loss = nn.L1Loss() def forward(self, pred, target): # 计算 L1 Loss l1_val = self.l1_loss(pred, target) # 计算 SSIM (取值越高,图像越相似,所以这里是 1 - SSIM) # data_range 可以根据图像像素值范围调整,一般夜光数据可先标准化到 [0,1] ssim_val = ssim(pred, target, data_range=1.0) total_loss = self.alpha * l1_val + (1.0 - self.alpha) * (1.0 - ssim_val) return total_loss

如上,如果你的数据像素值范围不是 [0,1],请务必设置好 data_range= 参数(例如最大像素值为 255 则设 data_range=255)。


3. 评估脚本:evaluators/pansharpening_evaluator.py

如果目前只想维持原有的评估流程(用 MSE 作参考指标),那就不需要大改。也可以根据实际需要,引入其他指标(PSNR、SSIM 等)。下面代码保持和你给出的示例基本一致,仅示例性贴出:

python
# evaluators/pansharpening_evaluator.py import torch import numpy as np def evaluate(model, data_loader, device): model.eval() mse_list = [] with torch.no_grad(): for inputs, targets in data_loader: inputs = inputs.to(device) targets = targets.to(device) outputs = model(inputs) mse = torch.mean((outputs - targets) ** 2).item() mse_list.append(mse) avg_mse = np.mean(mse_list) return avg_mse

四、总结

  • 损失函数选择建议

    • 夜光超分辨率中,单纯 MSE 容易导致图像模糊;
    • 可使用 L1Loss 代替 MSE,或者组合 L1 与 SSIM;
    • 如果想在高层次特征方面增强图像质量,可再考虑 Perceptual Loss、GAN Loss 等。
  • 使用方法

    • NightLightLoss 中设置合适的 alpha
    • 安装并正确配置 pytorch_msssim
    • 对输入数据进行相应归一化或标准化;
    • 如果不需要 SSIM,则可以将其删除或将 alpha=1.0 且不计算 ssim。
  • 重要注意点

    • 如果你的输入或输出图像值域不是 [0, 1],在计算 SSIM 时要调整 data_range
    • 保证训练时的张量形状与 SSIM 计算的维度匹配 (Batch, Channel, Height, Width);
    • 可能需要针对夜光图像的实际分布做数据预处理或对数变换,以减少大光斑或极亮区域对损失的影响。

这样,你就可以更好地在夜光图像超分辨率重建中兼顾像素级一致性与夜光图像的细节结构,从而取得更好的定量指标与可视化效果。祝研究/项目顺利!

इस प्रश्नोत्तर को साझा करें