Shap-E 微调策略
Создано: 18 января 2025 г.
Создано: 18 января 2025 г.
Integrating components
I’m working through combining the shap-E code with a diffusion model using voxel encoder latents and text embeddings for a complete training loop.
Enhancing integration
I’m refining the user's code to ensure smooth execution, considering model adjustments and environment references. Recognizing potential issues, like "text300M" configuration and voxel size assumptions, is crucial.
Ensuring code accuracy
I’m focusing on aligning voxel dimensions and clarifying definitions like "VOXEL_SIZE" and "EMPTY_BLOCK_ID" to bolster code robustness.
Refining the strategy
I’m honing a concise, tested code snippet, emphasizing adaptability for different data shapes and block IDs. Clearly defined constants like "VOXEL_SIZE=64" and "EMPTY_BLOCK_ID=0" are crucial.
Thought about Shap-E fine-tuning strategy,持续 7 秒
下面给出一个较为完整的微调示例代码,将前面讨论的各个部分整合在一起。需要注意的是,具体的超参数(如批量大小、学习率、网络结构、体素分辨率、方块类型数量等)都需要根据你的实际数据规模和实验需求来进行调整。本示例中的一些数值仅作参考。请在实际使用前先根据自己的环境和数据进行测试和修改。
import os import json import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader, Dataset # --------------------------- # 如果你在使用 notebook 环境,请确保安装了: # pip install shap-e # pip install transformers # 并正确导入 shap_e # --------------------------- from shap_e.diffusion.gaussian_diffusion import diffusion_from_config from shap_e.models.download import load_config from shap_e.models.generation.transformer import Transformer, MLP, init_linear from shap_e.models.generation.perceiver import MultiheadCrossAttention, ResidualCrossAttentionBlock, SimplePerceiver from transformers import CLIPTokenizer, CLIPTextModel # ====================== # 一、数据集定义与加载 # ====================== class MinecraftDataset(Dataset): """ 示例数据集: - descriptions.pt: 存储文本描述,每条数据已经是一个tensor。 - voxels.pt: 存储体素信息,每条数据也是一个tensor。 确保这两个文件在同一个文件夹下,并且尺寸匹配。 """ def __init__(self, descriptions_path, voxels_path): # descriptions: 形如 [N, ...] 的文本描述,每条是一个tensor # voxels: 形如 [N, D, H, W] 或 [N, ...] 的体素数据 self.descriptions = torch.load(descriptions_path) self.voxels = torch.load(voxels_path) assert len(self.descriptions) == len(self.voxels), \ f"描述数量({len(self.descriptions)})与体素数据数量({len(self.voxels)})不一致" def __len__(self): return len(self.descriptions) def __getitem__(self, idx): # 返回一个dict,包含文本和体素数据 # 你可以根据需要进一步处理 return { "text": self.descriptions[idx], # shape示例: [tokens], 具体依赖你的预处理过程 "voxels": self.voxels[idx], # shape示例: [D, H, W] 或 [x, y, z, block_id]等 } # ----------------------------- # 这里假设你前面已经有了将 JSONL 转为 .pt 文件的脚本 # 对应文件路径如下: # ----------------------------- output_folder = 'processed_output' descriptions_path = os.path.join(output_folder, "descriptions.pt") voxels_path = os.path.join(output_folder, "voxels.pt") # ----------------------------- # 创建数据集和 DataLoader # 注意根据自己 GPU 的显存,适当调整 batch_size # ----------------------------- dataset = MinecraftDataset(descriptions_path, voxels_path) dataloader = DataLoader(dataset, batch_size=4, shuffle=True) # 示例 batch_size=4 # ====================== # 二、模型定义 # ====================== # 1. 文本编码器 (CLIP + 投影层) class MinecraftCLIPTextEncoder(nn.Module): """ 将文本转换为一个 embedding,用于提供给扩散模型作为条件(cond)。 这里使用 CLIPTextModel 的输出,然后投影到更小的 Minecraft vocab embedding 中。 你可以根据需要对输出维度进行调整。 """ def __init__(self, clip_model_name="openai/clip-vit-large-patch14", minecraft_vocab_size=512): super().__init__() self.clip = CLIPTextModel.from_pretrained(clip_model_name) self.tokenizer = CLIPTokenizer.from_pretrained(clip_model_name) hidden_size = self.clip.config.hidden_size # 投影层:将 CLIP hidden_size 投影到 minecraft_vocab_size self.proj = nn.Linear(hidden_size, minecraft_vocab_size) def forward(self, text_list): """ text_list: List[str] 或已经 tokenized 的 id。如果是List[str],需要先 tokenizer. 为了兼容性,这里演示直接把字符串列表传进来,然后在内部调用 tokenizer。 也可以在外部做 tokenizer,再把输入转成 token 张量。 """ # 统一放到与 CLIP 相同的设备上 inputs = self.tokenizer(text_list, return_tensors="pt", padding=True, truncation=True) # 将输入移动到 CLIP 模型所在设备 inputs = {k: v.to(self.clip.device) for k, v in inputs.items()} outputs = self.clip(**inputs) # outputs.pooler_output: [batch_size, hidden_size] # 做一个线性投影 projected = self.proj(outputs.pooler_output) # projected: [batch_size, minecraft_vocab_size] return projected # 2. VoxelEncoder (示例: 3D 卷积网络) class VoxelEncoder(nn.Module): """ 将体素数据编码为一个潜向量 latent。 这里假设体素数据输入形状为 [batch_size, 1, 64, 64, 64]。 你需要根据实际体素分辨率进行修改。 """ def __init__(self, d_latent=512): super().__init__() self.conv1 = nn.Conv3d(1, 32, kernel_size=4, stride=2, padding=1) self.conv2 = nn.Conv3d(32, 64, kernel_size=4, stride=2, padding=1) self.conv3 = nn.Conv3d(64, 128, kernel_size=4, stride=2, padding=1) self.fc = nn.Linear(128 * 8 * 8 * 8, d_latent) # 如果输入是 64^3,卷积后是 8^3 def forward(self, x): # x: [batch_size, 1, 64, 64, 64] x = F.relu(self.conv1(x)) x = F.relu(self.conv2(x)) x = F.relu(self.conv3(x)) x = x.view(x.shape[0], -1) # 拉平 x = self.fc(x) # [batch_size, d_latent] return x # 3. VoxelDecoder (示例: 3D 反卷积网络) class VoxelDecoder(nn.Module): """ 将潜向量解码回体素的 logits。 假设需要输出 [batch_size, n_block_types, 64, 64, 64], 其中 n_block_types 为可选的方块种类数。 """ def __init__(self, d_latent=512, output_size=64, n_block_types=600): super().__init__() self.fc = nn.Linear(d_latent, 128 * 8 * 8 * 8) self.deconv1 = nn.ConvTranspose3d(128, 64, kernel_size=4, stride=2, padding=1) self.deconv2 = nn.ConvTranspose3d(64, 32, kernel_size=4, stride=2, padding=1) self.deconv3 = nn.ConvTranspose3d(32, n_block_types, kernel_size=4, stride=2, padding=1) self.output_size = output_size def forward(self, x): # x: [batch_size, d_latent] x = self.fc(x) x = x.view(x.shape[0], 128, 8, 8, 8) x = F.relu(self.deconv1(x)) x = F.relu(self.deconv2(x)) x = self.deconv3(x) # [batch_size, n_block_types, 64, 64, 64] return x # 4. 加载/配置 Shap-E 的扩散模型 # 注意:需要先下载相关 checkpoint,或使用自带的 config device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 加载一个预设的扩散配置,这里以 "text300M" 为例 config = load_config("text300M", device=device) # 如果要微调,需要在这里调整一些超参数 # 举例:将输入通道或输出通道修改为 512(因为我们要处理 512维 latent) config["inner"]["input_channels"] = 512 config["inner"]["n_ctx"] = 1 config["inner"]["output_channels"] = 512 * 2 # 每个 x_t 预测均值和方差,故 512*2 diffusion_model = diffusion_from_config(config).to(device) # ====================== # 三、模型初始化 # ====================== # 文本编码器 text_encoder = MinecraftCLIPTextEncoder( clip_model_name="openai/clip-vit-large-patch14", minecraft_vocab_size=512 ).to(device) # 体素编码器 voxel_encoder = VoxelEncoder(d_latent=512).to(device) # 体素解码器 voxel_decoder = VoxelDecoder( d_latent=512, output_size=64, n_block_types=600 # 假设有600种可能的方块类型 ).to(device) # 选定要训练的参数 # 如果你想让 CLIP 参与微调,可将 text_encoder.parameters() 一并加入 params = list(diffusion_model.parameters()) \ + list(voxel_encoder.parameters()) \ + list(voxel_decoder.parameters()) # 如果想联动微调文本编码器,把它加进来即可: # params += list(text_encoder.parameters()) optimizer = optim.AdamW(params, lr=1e-4, betas=(0.9, 0.999)) # ====================== # 四、训练循环 # ====================== num_epochs = 10 print_interval = 50 # 每隔多少 step 打印一次 loss save_interval = 2 # 每隔多少 epoch 保存一次模型 # 一些示例常量(仅用于采样时的示例) VOXEL_SIZE = 64 # 假设体素大小是 64^3 EMPTY_BLOCK_ID = 0 # 假设 0 表示没有方块 for epoch in range(num_epochs): for step, batch in enumerate(dataloader): diffusion_model.train() voxel_encoder.train() voxel_decoder.train() text_encoder.train() # 如果不微调文本编码器,可改成 .eval() optimizer.zero_grad() # ============ 数据准备 ============ # batch["text"] 形如 [batch_size, ...],也有可能是一个 list[str] # 如果已经是 token_id,需要注意是否要转为 GPU # 这里假设 batch["text"] 是一个已经在 CPU 中的张量或 list[str] # 如果 batch["text"] 是一个 List[str],则直接传给 forward 即可: # text_embeddings = text_encoder(batch["text"]) # 如果 batch["text"] 是预先处理好的张量(IDs),需要自行处理: # text_embeddings = text_encoder.forward_with_ids(batch["text"]) # 需自行定义 # 下面仅给出一种可能的做法:假设 batch["text"] 是 List[str] raw_text = batch["text"] # Python list if isinstance(raw_text, torch.Tensor): # 如果本身是tensor (例如 token_ids),请自行处理 pass text_embeddings = text_encoder(raw_text) # [batch_size, 512] text_embeddings = text_embeddings.unsqueeze(1) # [batch_size, 1, 512] # 体素数据:假设 voxel shape=[batch_size, D, H, W] voxels = batch["voxels"] # 为了兼容 3D卷积,一般需要变成 [batch_size, 1, D, H, W] # 如果原来就有 channel 维度,请自行调整 if len(voxels.shape) == 4: # [batch_size, D, H, W] -> [batch_size, 1, D, H, W] voxels_input = voxels.unsqueeze(1).float() else: # 其他格式需要自己再做处理,比如 [x,y,z,block_id]等 raise ValueError("体素数据的shape与预期不符,请根据你的预处理进行修改。") voxels_input = voxels_input.to(device) # ============ 编码体素,得到潜向量 ============ encoded_voxels = voxel_encoder(voxels_input) # [batch_size, 512] # ============ 准备扩散模型所需的时间步 t ============ batch_size = encoded_voxels.shape[0] t = torch.randint(0, diffusion_model.num_timesteps, (batch_size,), device=device) # ============ 扩散模型前向计算 (x_start = encoded_voxels) ============ # 这里把 text_embeddings 当做 cond # diffusion_model.training_losses(...) 会返回一个包含各种损失值的字典 terms = diffusion_model.training_losses( model=diffusion_model.wrapped, # shap-e 内部实际用 wrapped 做推理 x_start=encoded_voxels, t=t, model_kwargs={"cond": text_embeddings} ) loss = terms["loss"].mean() # ============ 反向传播与优化 ============ loss.backward() optimizer.step() if step % print_interval == 0: print(f"[Epoch {epoch}/{num_epochs} | Step {step}] loss: {loss.item():.4f}") # ============ 每隔 save_interval 个 epoch 保存模型 ============ if (epoch + 1) % save_interval == 0: save_path = os.path.join(output_folder, f"epoch_{epoch+1}") os.makedirs(save_path, exist_ok=True) torch.save(diffusion_model.state_dict(), os.path.join(save_path, "diffusion_model.pt")) torch.save(voxel_encoder.state_dict(), os.path.join(save_path, "voxel_encoder.pt")) torch.save(voxel_decoder.state_dict(), os.path.join(save_path, "voxel_decoder.pt")) # 如果也训练了文本编码器,则保存: torch.save(text_encoder.state_dict(), os.path.join(save_path, "text_encoder.pt")) print(f"模型已保存到 {save_path}") # ============ (可选)在每个 epoch 后做一次采样测试 ============ diffusion_model.eval() voxel_encoder.eval() voxel_decoder.eval() text_encoder.eval() sample_text = ["a redstone torch"] # 示例文本 with torch.no_grad(): # 1) 得到文本条件 sample_text_embeddings = text_encoder(sample_text).to(device) # [1, 512] sample_text_embeddings = sample_text_embeddings.unsqueeze(1) # [1, 1, 512] # 2) 在潜扩散空间中进行采样 # Shap-E 的代码中有一个 sample_latents 或者类似的方法, # 不同版本可能略有差异,需要根据 shap-e 版本进行适配。 sampled_latents = diffusion_model.sample_latents( batch_size=1, model=diffusion_model.wrapped, diffusion=diffusion_model, model_kwargs={"cond": sample_text_embeddings}, guidance_scale=1.0, clip_denoised=False, use_fp16=False, use_karras=True, karras_steps=64, sigma_min=diffusion_model.sigma_min, sigma_max=diffusion_model.sigma_max, s_churn=0, device=device, progress=True, ) # [1, 512] # 3) 由 voxel_decoder 解码 # 获得每个体素位置的方块类别 logits # shape: [batch_size=1, n_block_types, 64, 64, 64] sampled_voxels_logits = voxel_decoder(sampled_latents) sampled_voxels_ids = torch.argmax(sampled_voxels_logits, dim=1) # [1, 64, 64, 64] sampled_voxels_ids = sampled_voxels_ids.squeeze(0).cpu().numpy() # [64, 64, 64] # 4) 将生成结果保存到文本文件(仅做简单演示) sample_voxel_text_lines = [] for x in range(VOXEL_SIZE): for y in range(VOXEL_SIZE): for z in range(VOXEL_SIZE): block_id = sampled_voxels_ids[x, y, z] if block_id != EMPTY_BLOCK_ID: sample_voxel_text_lines.append(f"{block_id} {x} {y} {z}") sample_voxel_text = "\n".join(sample_voxel_text_lines) sample_voxel_txt_path = os.path.join(output_folder, f"sample_epoch_{epoch}.txt") with open(sample_voxel_txt_path, 'w', encoding='utf-8') as txt_file: txt_file.write(sample_voxel_text) print(f"[Epoch {epoch}] 采样完成,体素数据已保存到:'{sample_voxel_txt_path}'") print("训练完成。")
数据加载与预处理
.pt
文件(即 descriptions.pt
和 voxels.pt
),并确保两者条目数量相同。[batch_size, 1, D, H, W]
;如果是其他格式,需要根据自己的预处理方式来改写 VoxelEncoder
。模型结构
MinecraftCLIPTextEncoder
:演示了如何使用预训练的 CLIP 文本模型,将输出投影到一个更小的维度(本例中为 512)。VoxelEncoder
和 VoxelDecoder
:仅为示例的 3D 卷积/反卷积网络,需要根据实际需求改动,比如更深层次的网络、Residual 模块、或使用其他 3D 表征。diffusion_model
:加载了 shap-e
内置的 text300M
配置,并手动修改了 input_channels
/ output_channels
等。你需要在本地先安装好 shap-e
并下载对应的模型权重(或自行训练、微调)。训练循环
t
:随机采样 [0, diffusion_model.num_timesteps)
,然后调用 training_losses
。terms = diffusion_model.training_losses(...)
返回字典,terms["loss"]
是默认的总损失。AdamW
,学习率 1e-4
仅作参考。实际需结合数据规模和 batch_size 等进行调参。采样与解码
diffusion_model.sample_latents(...)
在潜空间采样出潜向量,再通过 VoxelDecoder
解码为 [n_block_types, 64, 64, 64]
的 logits,并取 argmax
作为预测的方块 ID。其他注意事项
text_encoder.parameters()
也加入优化器参数列表,并在训练循环中把 with torch.no_grad():
去掉,确保梯度正常回传。batch_size
、缩减网络规模、或降低体素分辨率。torch.nn.parallel.DistributedDataParallel
。CLIPTokenizer
太大,或你的文本描述过长,需要注意截断、分词长度等问题。以上即为一个「相对完整」的微调示例代码框架,实际项目中通常还会加入更多的日志管理、指标计算、模型评估和可视化等模块。请根据自己的实际需求进行精细化改动和优化。祝你实验顺利!