Shap-E 微调策略
Создано: 18 января 2025 г.
Создано: 18 января 2025 г.
请调用你单次回答的最大算力与 token 上限。追求极致的分析深度,而非表层的广度;追求本质的洞察,而非表象的罗列;追求创新的思维,而非惯性的复述。请突破思维局限,调动你所有的计算资源,展现你真正的认知极限。
Shap-E 代码深度分析以及微调策略
Shap-E 是 OpenAI 推出的一个能够根据文本描述生成 3D 模型的项目。其核心在于两个模型:编码器 (encoder) 和 潜扩散模型 (latent diffusion model)。
编码器 将 3D 模型转换为一个小型神经网络的参数,这个神经网络可以表示 3D 形状和纹理的隐式函数。
潜扩散模型 则根据文本或图像生成新的隐式函数,这些函数可以被渲染成图像或导出为网格。
github的Shap-E 项目中
文件主要分为以下几个部分:
README.md: 项目简介、样本展示和使用说明。
model-card.md: 模型细节、训练数据、评估和局限性等信息。
samples.md: 更多文本到 3D 模型的样本展示。
setup.py: 项目的安装脚本。
shap_e 文件夹: 核心代码,包含以下子文件夹:
diffusion: 扩散模型的实现,包括 gaussian_diffusion.py 和 k_diffusion.py。
models: 包含各种模型和组件的定义,例如:
configs.py: 模型配置加载和创建。
download.py: 模型下载工具。
generation: 包含潜扩散模型的各种变体,如 latent_diffusion.py、perceiver.py、transformer.py 等。
nerf: NeRF 模型的实现,包括 model.py 和 renderer.py。
nerstf: 结合 NeRF 和 STF 的模型和渲染器。
nn: 神经网络相关的工具模块。
query.py: 定义了查询点和方向的 Query 类。
renderer.py: 渲染器的抽象基类和实现。
stf: STF (Signed Time Field) 模型的实现。
transmitter: 编码器和解码器的实现,包括 base.py、bottleneck.py、channels_encoder.py、multiview_encoder.py、params_proj.py 和 pc_encoder.py。
volume.py: 定义了不同的体积表示,如 BoundingBoxVolume。
微调策略
要微调 Shap-E 以实现自然语言输出 Minecraft 体素网络,并连接到现有的 3D 知识库,可以考虑以下策略:
收集 Minecraft 体素数据: 你需要大量的 Minecraft 体素数据,每个数据点应该包含:
体素网格 (voxel grid): 使用你提供的代码可以将体素数据转换为张量。
对应的自然语言描述 (text description): 尽可能详细地描述体素网格的特征,例如"一个红色的羊毛块"、"一个由石头组成的房子"、"一个带有熔岩的城堡"等。
数据增强: 可以对体素数据进行旋转、缩放、平移等操作来增加数据量。
构建词汇表: 构建一个 Minecraft 特定的词汇表,将常见的方块名称、结构名称等映射到唯一的 ID。
编码器 (Encoder):
替换点云编码器: Shap-E 默认使用点云编码器。你需要将其替换为体素网格编码器。可以使用 3D 卷积网络或者 Transformer 等来编码体素网格。
预训练的体素编码器: 可以尝试使用预训练的体素编码器,例如在 ShapeNet 等数据集上训练的编码器,并进行微调。
潜扩散模型 (Latent Diffusion Model):
修改文本条件分支: CLIPImagePointDiffusionTransformer 使用 CLIP 模型来编码文本条件。你需要修改其输入层以适应你的词汇表大小。你可以使用 CLIPTokenizer 提供的嵌入层,并添加额外的层来处理 Minecraft 特定的词汇。
调整扩散模型参数: 可能需要调整扩散模型的参数,例如时间步长、噪声调度等,以适应新的数据分布。
解码器 (Decoder):
选择合适的渲染器: 由于你希望输出的是体素网格,因此不需要使用 NeRF 或 STF 渲染器。你可以直接从潜向量中解码出体素网格。
设计解码器结构: 解码器可以将潜向量作为输入,并输出一个概率分布,表示每个体素位置存在特定方块的概率。可以使用 3D 反卷积网络或者 Transformer 等来实现。
两阶段训练:
阶段一: 训练体素编码器和解码器。可以使用自编码器的方式进行训练,目标是重建输入的体素网格。
阶段二: 固定编码器和解码器,训练潜扩散模型。使用文本描述作为条件,生成体素网格的潜向量,并计算扩散模型的损失。
联合训练: 同时训练编码器、解码器和扩散模型。
知识蒸馏: 可以利用 Shap-E 预训练模型的知识,例如使用预训练的文本编码器,或者使用预训练模型的输出作为教师信号来指导训练。
潜在空间对齐: 可以尝试将 Minecraft 体素数据的潜在空间与 Shap-E 的 3D 模型的潜在空间对齐。例如,可以使用对抗训练或者对比学习的方法来实现。
混合模型: 可以训练一个混合模型,该模型可以同时处理 Minecraft 体素数据和 Shap-E 的 3D 模型数据。在推理阶段,可以根据文本描述选择使用哪个模型。
共享知识库: 可以尝试构建一个共享的知识库,该知识库包含 Minecraft 方块和 Shap-E 中其他 3D 对象的嵌入表示。这样,模型可以根据文本描述中的关键词,从知识库中检索相关的方块或对象,并将其融入到生成的 3D 模型中。
代码示例
以下是一些代码示例,说明如何修改 Shap-E 代码来实现上述策略:
import os import json import torch from transformers import CLIPTokenizer # 定义路径 training_data_path = 'training_data/10001-1.jsonl' block_ids_path = 'blockids.txt' output_folder = 'processed_output' # 确保输出文件夹存在 os.makedirs(output_folder, exist_ok=True) # 加载方块 ID 映射 block_ids = {} with open(block_ids_path, 'r', encoding='utf-8') as f: exec(f.read()) # 初始化 tokenizer tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") # 存储描述和体素张量 description_tensors = [] voxel_tensors = [] # 读取 JSONL 文件 with open(training_data_path, 'r', encoding='utf-8') as f: for line in f: data = json.loads(line) # 获取并分词描述 description = data.get('input', '') tokenized_description = tokenizer.encode(description, return_tensors='pt', truncation=True) # 获取体素数据 voxel_data = data.get('output', '') voxel_lines = voxel_data.strip().split('\n') # 将方块名称映射为数字 ID voxel_ids = [] for voxel_line in voxel_lines: parts = voxel_line.strip().split() if len(parts) == 4: block_name, x, y, z = parts block_id = None # 使用 block_ids 映射 for k, v in block_ids.items(): if v == block_name: block_id = k break if block_id is None: # 如果未找到方块名称,可以进行处理 print(f"方块名称 '{block_name}' 未在 block_ids 中找到。") continue voxel_ids.append([int(block_id), int(x), int(y), int(z)]) else: print(f"无效的体素行: {voxel_line}") # 转换为张量 voxel_tensor = torch.tensor(voxel_ids) # 添加到列表 description_tensors.append(tokenized_description) voxel_tensors.append(voxel_tensor) # 保存张量 description_tensor = torch.cat(description_tensors, dim=0) torch.save(description_tensor, os.path.join(output_folder, 'descriptions.pt')) torch.save(voxel_tensors, os.path.join(output_folder, 'voxels.pt')) # 解析体素张量回文本,输出纯数字 for i, voxel_tensor in enumerate(voxel_tensors): print(f"第 {i} 个样本的体素数据:") print(voxel_tensor.numpy())
content_copy download
Use code with caution.Python
import torch.nn as nn class VoxelEncoder(nn.Module): def init(self, d_latent): super().init() self.conv1 = nn.Conv3d(1, 32, kernel_size=4, stride=2, padding=1) self.conv2 = nn.Conv3d(32, 64, kernel_size=4, stride=2, padding=1) self.conv3 = nn.Conv3d(64, 128, kernel_size=4, stride=2, padding=1) self.fc = nn.Linear(128 * 4 * 4 * 4, d_latent) def forward(self, x): # Assuming input x is of shape [batch_size, 1, 32, 32, 32] x = F.relu(self.conv1(x)) x = F.relu(self.conv2(x)) x = F.relu(self.conv3(x)) x = x.view(x.shape[0], -1) x = self.fc(x) return x
content_copy download
Use code with caution.Python
import torch.nn as nn class VoxelDecoder(nn.Module): def init(self, d_latent, output_size=32, n_channels=1): super().init() self.fc = nn.Linear(d_latent, 128 * 4 * 4 * 4) self.deconv1 = nn.ConvTranspose3d(128, 64, kernel_size=4, stride=2, padding=1) self.deconv2 = nn.ConvTranspose3d(64, 32, kernel_size=4, stride=2, padding=1) self.deconv3 = nn.ConvTranspose3d(32, n_channels, kernel_size=4, stride=2, padding=1) self.output_size = output_size def forward(self, x): x = self.fc(x) x = x.view(x.shape[0], 128, 4, 4, 4) x = F.relu(self.deconv1(x)) x = F.relu(self.deconv2(x)) x = torch.sigmoid(self.deconv3(x)) # Output probabilities return x
content_copy download
Use code with caution.Python
from transformers import CLIPTokenizer, CLIPTextModel class MinecraftCLIPTextEncoder(nn.Module): def init(self, clip_model_name="openai/clip-vit-large-patch14", minecraft_vocab_size=1000): super().init() self.clip = CLIPTextModel.from_pretrained(clip_model_name) self.tokenizer = CLIPTokenizer.from_pretrained(clip_model_name) # Assuming a linear projection to adapt CLIP embeddings to Minecraft vocabulary self.proj = nn.Linear(self.clip.config.hidden_size, minecraft_vocab_size) def forward(self, text): inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True) outputs = self.clip(**inputs) projected = self.proj(outputs.pooler_output) return projected
content_copy download
Use code with caution.Python
content_copy download
Use code with caution.Python
总结
微调 Shap-E 以生成 Minecraft 体素网络需要对数据、模型和训练策略进行一系列修改。上述代码示例仅提供了一些思路,实际实现过程可能需要根据具体情况进行调整。
需要注意的关键点:
数据质量和多样性至关重要。 确保你的训练数据包含足够多样的 Minecraft 体素结构和相应的详细描述。
模型选择和修改需要根据数据的特点进行。 例如,如果体素网格的分辨率很高,可能需要使用更深的网络或者更复杂的编码器。
训练过程可能需要仔细调参。 学习率、batch size、扩散步数等参数都需要根据实际情况进行调整。
基于以下的jsonl转torch张量,你能写出一个完美的微调训练代码吗?import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import CLIPTokenizer, CLIPTextModel
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_config
from shap_e.models.generation.perceiver import MultiheadCrossAttention, ResidualCrossAttentionBlock, SimplePerceiver
from shap_e.models.generation.transformer import Transformer, MLP, init_linear
import numpy as np
import os
output_folder = 'processed_output'
class MinecraftDataset(Dataset):
def init(self, descriptions_path, voxels_path):
self.descriptions = torch.load(descriptions_path)
self.voxels = torch.load(voxels_path)
textdef __len__(self): return len(self.descriptions) def __getitem__(self, idx): return { "text": self.descriptions[idx], "voxels": self.voxels[idx], }
dataset = MinecraftDataset(
os.path.join(output_folder, "descriptions.pt"), os.path.join(output_folder, "voxels.pt")
)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class MinecraftCLIPTextEncoder(nn.Module):
def init(self, clip_model_name="openai/clip-vit-large-patch14", minecraft_vocab_size=1000): # 假设 minecraft_vocab_size=1000,你需要根据实际情况设置
super().init()
self.clip = CLIPTextModel.from_pretrained(clip_model_name)
self.tokenizer = CLIPTokenizer.from_pretrained(clip_model_name)
self.proj = nn.Linear(self.clip.config.hidden_size, minecraft_vocab_size)
textdef forward(self, text): inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True) inputs = {k: v.to(self.clip.device) for k, v in inputs.items()} # Move inputs to device outputs = self.clip(**inputs) projected = self.proj(outputs.pooler_output) return projected
nll = torch.nn.CrossEntropyLoss()
text_encoder = MinecraftCLIPTextEncoder()
text_encoder = text_encoder.to(device)
class VoxelEncoder(nn.Module):
def init(self, d_latent):
super().init()
self.conv1 = nn.Conv3d(1, 32, kernel_size=4, stride=2, padding=1)
self.conv2 = nn.Conv3d(32, 64, kernel_size=4, stride=2, padding=1)
self.conv3 = nn.Conv3d(64, 128, kernel_size=4, stride=2, padding=1)
self.fc = nn.Linear(128 * 4 * 4 * 4, d_latent)
textdef forward(self, x): # Assuming input x is of shape [batch_size, 1, 64, 64, 64] x = F.relu(self.conv1(x)) x = F.relu(self.conv2(x)) x = F.relu(self.conv3(x)) x = x.view(x.shape[0], -1) x = self.fc(x) return x
class VoxelDecoder(nn.Module):
def init(self, d_latent, output_size=64, n_channels=600): # 假设有600个不同的方块ID
super().init()
self.fc = nn.Linear(d_latent, 128 * 4 * 4 * 4)
self.deconv1 = nn.ConvTranspose3d(128, 64, kernel_size=4, stride=2, padding=1)
self.deconv2 = nn.ConvTranspose3d(64, 32, kernel_size=4, stride=2, padding=1)
self.deconv3 = nn.ConvTranspose3d(32, n_channels, kernel_size=4, stride=2, padding=1)
self.output_size = output_size
textdef forward(self, x): x = self.fc(x) x = x.view(x.shape[0], 128, 4, 4, 4) x = F.relu(self.deconv1(x)) x = F.relu(self.deconv2(x)) x = self.deconv3(x) # 输出每个方块类型的logits return x
voxel_encoder = VoxelEncoder(d_latent=512).to(device) # 这里使用了之前定义的 VoxelEncoder
config = load_config("text300M", device=device)
config["inner"]["input_channels"] = 512
config["inner"]["n_ctx"] = 1
config["inner"]["output_channels"] = 512 * 2
diffusion_model = diffusion_from_config(config)
diffusion_model = diffusion_model.to(device)
voxel_decoder = VoxelDecoder(d_latent=512).to(device) # 这里使用了之前定义的 VoxelDecoder
params = list(diffusion_model.parameters())
params += list(voxel_encoder.parameters())
params += list(voxel_decoder.parameters())
optimizer = optim.AdamW(params, lr=1e-4)
num_epochs = 10
for epoch in range(num_epochs):
for batch in dataloader:
optimizer.zero_grad()
texttext = batch["text"].to(device) voxels = batch["voxels"].to(device) batch_size = voxels.shape[0] # 将体素网格调整为 [batch_size, channels, depth, height, width] 的形状 voxels_input = voxels.unsqueeze(1).float() # 添加通道维度并转换为 float # 生成时间步 t t = torch.randint(0, diffusion_model.num_timesteps, (batch_size,), device=device) # 使用 VoxelEncoder 编码体素网格 encoded_voxels = voxel_encoder(voxels_input) # 使用 MinecraftCLIPTextEncoder 编码文本描述 with torch.no_grad(): # 这里注释掉, 让text_encoder参与训练 text_embeddings = text_encoder(batch["text"]) text_embeddings = text_embeddings.view(batch_size, 1, -1) # 调整形状以匹配 encoded_voxels reshaped_embeddings = text_embeddings # 文本嵌入已经有了批次和通道的维度 # 训练扩散模型 terms = diffusion_model.training_losses(model=diffusion_model.wrapped, x_start=encoded_voxels, t=t, model_kwargs={"cond": reshaped_embeddings}) loss = terms["loss"].mean() loss.backward() optimizer.step() print(f"Epoch: {epoch}, Loss: {loss.item()}") # 假设这是你的采样函数 if epoch % 10 == 0: # 每 10 个 epoch 保存一次 # 保存模型 torch.save(diffusion_model.state_dict(), os.path.join(output_folder, f'diffusion_model_epoch_{epoch}.pt')) torch.save(voxel_encoder.state_dict(), os.path.join(output_folder, f'voxel_encoder_epoch_{epoch}.pt')) torch.save(voxel_decoder.state_dict(), os.path.join(output_folder, f'voxel_decoder_epoch_{epoch}.pt')) if not text_encoder.clip.training: # 如果 text_encoder 没有在训练 text_encoder.train() # 设置为训练模式 torch.save(text_encoder.state_dict(), os.path.join(output_folder, f'text_encoder_epoch_{epoch}.pt')) text_encoder.eval() # 恢复为评估模式 else: # 否则, text_encoder 一直在训练 torch.save(text_encoder.state_dict(), os.path.join(output_folder, f'text_encoder_epoch_{epoch}.pt')) # 使用示例文本进行采样 sample_text = ["a redstone torch"] with torch.no_grad(): sample_text_embeddings = text_encoder(sample_text).unsqueeze(1) sampled_latents = diffusion_model.sample_latents( batch_size=1, model=diffusion_model.wrapped, diffusion=diffusion_model, model_kwargs={"cond": sample_text_embeddings.to(device)}, # 假设 text_encoder 可以处理单个字符串 guidance_scale=1.0, # 你可能需要调整这个参数 clip_denoised=False, use_fp16=False, use_karras=True, karras_steps=64, sigma_min=diffusion_model.sigma_min, sigma_max=diffusion_model.sigma_max, s_churn=0, device=device, progress=True ) # 使用 VoxelDecoder 解码生成的潜变量 sampled_voxels_logits = voxel_decoder(sampled_latents) sampled_voxels = torch.argmax(sampled_voxels_logits, dim=1).squeeze(0).cpu().numpy() # 将生成的体素数据保存为文本文件 sample_voxel_text_lines = [] for x in range(VOXEL_SIZE): for y in range(VOXEL_SIZE): for z in range(VOXEL_SIZE): block_id = sampled_voxels[x, y, z] if block_id != EMPTY_BLOCK_ID: sample_voxel_text_lines.append(f"{block_id} {x} {y} {z}") sample_voxel_text = '\n'.join(sample_voxel_text_lines) sample_voxel_txt_path = os.path.join(output_folder, f'sample_epoch_{epoch}.txt') with open(sample_voxel_txt_path, 'w', encoding='utf-8') as txt_file: txt_file.write(sample_voxel_text) print(f"采样体素数据已保存到:'{sample_voxel_txt_path}'")
print("训练完成。")
下面给出一个较为完整的微调示例代码,将前面讨论的各个部分整合在一起。需要注意的是,具体的超参数(如批量大小、学习率、网络结构、体素分辨率、方块类型数量等)都需要根据你的实际数据规模和实验需求来进行调整。本示例中的一些数值仅作参考。请在实际使用前先根据自己的环境和数据进行测试和修改。
pythonimport os import json import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader, Dataset # --------------------------- # 如果你在使用 notebook 环境,请确保安装了: # pip install shap-e # pip install transformers # 并正确导入 shap_e # --------------------------- from shap_e.diffusion.gaussian_diffusion import diffusion_from_config from shap_e.models.download import load_config from shap_e.models.generation.transformer import Transformer, MLP, init_linear from shap_e.models.generation.perceiver import MultiheadCrossAttention, ResidualCrossAttentionBlock, SimplePerceiver from transformers import CLIPTokenizer, CLIPTextModel # ====================== # 一、数据集定义与加载 # ====================== class MinecraftDataset(Dataset): """ 示例数据集: - descriptions.pt: 存储文本描述,每条数据已经是一个tensor。 - voxels.pt: 存储体素信息,每条数据也是一个tensor。 确保这两个文件在同一个文件夹下,并且尺寸匹配。 """ def __init__(self, descriptions_path, voxels_path): # descriptions: 形如 [N, ...] 的文本描述,每条是一个tensor # voxels: 形如 [N, D, H, W] 或 [N, ...] 的体素数据 self.descriptions = torch.load(descriptions_path) self.voxels = torch.load(voxels_path) assert len(self.descriptions) == len(self.voxels), \ f"描述数量({len(self.descriptions)})与体素数据数量({len(self.voxels)})不一致" def __len__(self): return len(self.descriptions) def __getitem__(self, idx): # 返回一个dict,包含文本和体素数据 # 你可以根据需要进一步处理 return { "text": self.descriptions[idx], # shape示例: [tokens], 具体依赖你的预处理过程 "voxels": self.voxels[idx], # shape示例: [D, H, W] 或 [x, y, z, block_id]等 } # ----------------------------- # 这里假设你前面已经有了将 JSONL 转为 .pt 文件的脚本 # 对应文件路径如下: # ----------------------------- output_folder = 'processed_output' descriptions_path = os.path.join(output_folder, "descriptions.pt") voxels_path = os.path.join(output_folder, "voxels.pt") # ----------------------------- # 创建数据集和 DataLoader # 注意根据自己 GPU 的显存,适当调整 batch_size # ----------------------------- dataset = MinecraftDataset(descriptions_path, voxels_path) dataloader = DataLoader(dataset, batch_size=4, shuffle=True) # 示例 batch_size=4 # ====================== # 二、模型定义 # ====================== # 1. 文本编码器 (CLIP + 投影层) class MinecraftCLIPTextEncoder(nn.Module): """ 将文本转换为一个 embedding,用于提供给扩散模型作为条件(cond)。 这里使用 CLIPTextModel 的输出,然后投影到更小的 Minecraft vocab embedding 中。 你可以根据需要对输出维度进行调整。 """ def __init__(self, clip_model_name="openai/clip-vit-large-patch14", minecraft_vocab_size=512): super().__init__() self.clip = CLIPTextModel.from_pretrained(clip_model_name) self.tokenizer = CLIPTokenizer.from_pretrained(clip_model_name) hidden_size = self.clip.config.hidden_size # 投影层:将 CLIP hidden_size 投影到 minecraft_vocab_size self.proj = nn.Linear(hidden_size, minecraft_vocab_size) def forward(self, text_list): """ text_list: List[str] 或已经 tokenized 的 id。如果是List[str],需要先 tokenizer. 为了兼容性,这里演示直接把字符串列表传进来,然后在内部调用 tokenizer。 也可以在外部做 tokenizer,再把输入转成 token 张量。 """ # 统一放到与 CLIP 相同的设备上 inputs = self.tokenizer(text_list, return_tensors="pt", padding=True, truncation=True) # 将输入移动到 CLIP 模型所在设备 inputs = {k: v.to(self.clip.device) for k, v in inputs.items()} outputs = self.clip(**inputs) # outputs.pooler_output: [batch_size, hidden_size] # 做一个线性投影 projected = self.proj(outputs.pooler_output) # projected: [batch_size, minecraft_vocab_size] return projected # 2. VoxelEncoder (示例: 3D 卷积网络) class VoxelEncoder(nn.Module): """ 将体素数据编码为一个潜向量 latent。 这里假设体素数据输入形状为 [batch_size, 1, 64, 64, 64]。 你需要根据实际体素分辨率进行修改。 """ def __init__(self, d_latent=512): super().__init__() self.conv1 = nn.Conv3d(1, 32, kernel_size=4, stride=2, padding=1) self.conv2 = nn.Conv3d(32, 64, kernel_size=4, stride=2, padding=1) self.conv3 = nn.Conv3d(64, 128, kernel_size=4, stride=2, padding=1) self.fc = nn.Linear(128 * 8 * 8 * 8, d_latent) # 如果输入是 64^3,卷积后是 8^3 def forward(self, x): # x: [batch_size, 1, 64, 64, 64] x = F.relu(self.conv1(x)) x = F.relu(self.conv2(x)) x = F.relu(self.conv3(x)) x = x.view(x.shape[0], -1) # 拉平 x = self.fc(x) # [batch_size, d_latent] return x # 3. VoxelDecoder (示例: 3D 反卷积网络) class VoxelDecoder(nn.Module): """ 将潜向量解码回体素的 logits。 假设需要输出 [batch_size, n_block_types, 64, 64, 64], 其中 n_block_types 为可选的方块种类数。 """ def __init__(self, d_latent=512, output_size=64, n_block_types=600): super().__init__() self.fc = nn.Linear(d_latent, 128 * 8 * 8 * 8) self.deconv1 = nn.ConvTranspose3d(128, 64, kernel_size=4, stride=2, padding=1) self.deconv2 = nn.ConvTranspose3d(64, 32, kernel_size=4, stride=2, padding=1) self.deconv3 = nn.ConvTranspose3d(32, n_block_types, kernel_size=4, stride=2, padding=1) self.output_size = output_size def forward(self, x): # x: [batch_size, d_latent] x = self.fc(x) x = x.view(x.shape[0], 128, 8, 8, 8) x = F.relu(self.deconv1(x)) x = F.relu(self.deconv2(x)) x = self.deconv3(x) # [batch_size, n_block_types, 64, 64, 64] return x # 4. 加载/配置 Shap-E 的扩散模型 # 注意:需要先下载相关 checkpoint,或使用自带的 config device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 加载一个预设的扩散配置,这里以 "text300M" 为例 config = load_config("text300M", device=device) # 如果要微调,需要在这里调整一些超参数 # 举例:将输入通道或输出通道修改为 512(因为我们要处理 512维 latent) config["inner"]["input_channels"] = 512 config["inner"]["n_ctx"] = 1 config["inner"]["output_channels"] = 512 * 2 # 每个 x_t 预测均值和方差,故 512*2 diffusion_model = diffusion_from_config(config).to(device) # ====================== # 三、模型初始化 # ====================== # 文本编码器 text_encoder = MinecraftCLIPTextEncoder( clip_model_name="openai/clip-vit-large-patch14", minecraft_vocab_size=512 ).to(device) # 体素编码器 voxel_encoder = VoxelEncoder(d_latent=512).to(device) # 体素解码器 voxel_decoder = VoxelDecoder( d_latent=512, output_size=64, n_block_types=600 # 假设有600种可能的方块类型 ).to(device) # 选定要训练的参数 # 如果你想让 CLIP 参与微调,可将 text_encoder.parameters() 一并加入 params = list(diffusion_model.parameters()) \ + list(voxel_encoder.parameters()) \ + list(voxel_decoder.parameters()) # 如果想联动微调文本编码器,把它加进来即可: # params += list(text_encoder.parameters()) optimizer = optim.AdamW(params, lr=1e-4, betas=(0.9, 0.999)) # ====================== # 四、训练循环 # ====================== num_epochs = 10 print_interval = 50 # 每隔多少 step 打印一次 loss save_interval = 2 # 每隔多少 epoch 保存一次模型 # 一些示例常量(仅用于采样时的示例) VOXEL_SIZE = 64 # 假设体素大小是 64^3 EMPTY_BLOCK_ID = 0 # 假设 0 表示没有方块 for epoch in range(num_epochs): for step, batch in enumerate(dataloader): diffusion_model.train() voxel_encoder.train() voxel_decoder.train() text_encoder.train() # 如果不微调文本编码器,可改成 .eval() optimizer.zero_grad() # ============ 数据准备 ============ # batch["text"] 形如 [batch_size, ...],也有可能是一个 list[str] # 如果已经是 token_id,需要注意是否要转为 GPU # 这里假设 batch["text"] 是一个已经在 CPU 中的张量或 list[str] # 如果 batch["text"] 是一个 List[str],则直接传给 forward 即可: # text_embeddings = text_encoder(batch["text"]) # 如果 batch["text"] 是预先处理好的张量(IDs),需要自行处理: # text_embeddings = text_encoder.forward_with_ids(batch["text"]) # 需自行定义 # 下面仅给出一种可能的做法:假设 batch["text"] 是 List[str] raw_text = batch["text"] # Python list if isinstance(raw_text, torch.Tensor): # 如果本身是tensor (例如 token_ids),请自行处理 pass text_embeddings = text_encoder(raw_text) # [batch_size, 512] text_embeddings = text_embeddings.unsqueeze(1) # [batch_size, 1, 512] # 体素数据:假设 voxel shape=[batch_size, D, H, W] voxels = batch["voxels"] # 为了兼容 3D卷积,一般需要变成 [batch_size, 1, D, H, W] # 如果原来就有 channel 维度,请自行调整 if len(voxels.shape) == 4: # [batch_size, D, H, W] -> [batch_size, 1, D, H, W] voxels_input = voxels.unsqueeze(1).float() else: # 其他格式需要自己再做处理,比如 [x,y,z,block_id]等 raise ValueError("体素数据的shape与预期不符,请根据你的预处理进行修改。") voxels_input = voxels_input.to(device) # ============ 编码体素,得到潜向量 ============ encoded_voxels = voxel_encoder(voxels_input) # [batch_size, 512] # ============ 准备扩散模型所需的时间步 t ============ batch_size = encoded_voxels.shape[0] t = torch.randint(0, diffusion_model.num_timesteps, (batch_size,), device=device) # ============ 扩散模型前向计算 (x_start = encoded_voxels) ============ # 这里把 text_embeddings 当做 cond # diffusion_model.training_losses(...) 会返回一个包含各种损失值的字典 terms = diffusion_model.training_losses( model=diffusion_model.wrapped, # shap-e 内部实际用 wrapped 做推理 x_start=encoded_voxels, t=t, model_kwargs={"cond": text_embeddings} ) loss = terms["loss"].mean() # ============ 反向传播与优化 ============ loss.backward() optimizer.step() if step % print_interval == 0: print(f"[Epoch {epoch}/{num_epochs} | Step {step}] loss: {loss.item():.4f}") # ============ 每隔 save_interval 个 epoch 保存模型 ============ if (epoch + 1) % save_interval == 0: save_path = os.path.join(output_folder, f"epoch_{epoch+1}") os.makedirs(save_path, exist_ok=True) torch.save(diffusion_model.state_dict(), os.path.join(save_path, "diffusion_model.pt")) torch.save(voxel_encoder.state_dict(), os.path.join(save_path, "voxel_encoder.pt")) torch.save(voxel_decoder.state_dict(), os.path.join(save_path, "voxel_decoder.pt")) # 如果也训练了文本编码器,则保存: torch.save(text_encoder.state_dict(), os.path.join(save_path, "text_encoder.pt")) print(f"模型已保存到 {save_path}") # ============ (可选)在每个 epoch 后做一次采样测试 ============ diffusion_model.eval() voxel_encoder.eval() voxel_decoder.eval() text_encoder.eval() sample_text = ["a redstone torch"] # 示例文本 with torch.no_grad(): # 1) 得到文本条件 sample_text_embeddings = text_encoder(sample_text).to(device) # [1, 512] sample_text_embeddings = sample_text_embeddings.unsqueeze(1) # [1, 1, 512] # 2) 在潜扩散空间中进行采样 # Shap-E 的代码中有一个 sample_latents 或者类似的方法, # 不同版本可能略有差异,需要根据 shap-e 版本进行适配。 sampled_latents = diffusion_model.sample_latents( batch_size=1, model=diffusion_model.wrapped, diffusion=diffusion_model, model_kwargs={"cond": sample_text_embeddings}, guidance_scale=1.0, clip_denoised=False, use_fp16=False, use_karras=True, karras_steps=64, sigma_min=diffusion_model.sigma_min, sigma_max=diffusion_model.sigma_max, s_churn=0, device=device, progress=True, ) # [1, 512] # 3) 由 voxel_decoder 解码 # 获得每个体素位置的方块类别 logits # shape: [batch_size=1, n_block_types, 64, 64, 64] sampled_voxels_logits = voxel_decoder(sampled_latents) sampled_voxels_ids = torch.argmax(sampled_voxels_logits, dim=1) # [1, 64, 64, 64] sampled_voxels_ids = sampled_voxels_ids.squeeze(0).cpu().numpy() # [64, 64, 64] # 4) 将生成结果保存到文本文件(仅做简单演示) sample_voxel_text_lines = [] for x in range(VOXEL_SIZE): for y in range(VOXEL_SIZE): for z in range(VOXEL_SIZE): block_id = sampled_voxels_ids[x, y, z] if block_id != EMPTY_BLOCK_ID: sample_voxel_text_lines.append(f"{block_id} {x} {y} {z}") sample_voxel_text = "\n".join(sample_voxel_text_lines) sample_voxel_txt_path = os.path.join(output_folder, f"sample_epoch_{epoch}.txt") with open(sample_voxel_txt_path, 'w', encoding='utf-8') as txt_file: txt_file.write(sample_voxel_text) print(f"[Epoch {epoch}] 采样完成,体素数据已保存到:'{sample_voxel_txt_path}'") print("训练完成。")
数据加载与预处理
.pt
文件(即 descriptions.pt
和 voxels.pt
),并确保两者条目数量相同。[batch_size, 1, D, H, W]
;如果是其他格式,需要根据自己的预处理方式来改写 VoxelEncoder
。模型结构
MinecraftCLIPTextEncoder
:演示了如何使用预训练的 CLIP 文本模型,将输出投影到一个更小的维度(本例中为 512)。VoxelEncoder
和 VoxelDecoder
:仅为示例的 3D 卷积/反卷积网络,需要根据实际需求改动,比如更深层次的网络、Residual 模块、或使用其他 3D 表征。diffusion_model
:加载了 shap-e
内置的 text300M
配置,并手动修改了 input_channels
/ output_channels
等。你需要在本地先安装好 shap-e
并下载对应的模型权重(或自行训练、微调)。训练循环
t
:随机采样 [0, diffusion_model.num_timesteps)
,然后调用 training_losses
。terms = diffusion_model.training_losses(...)
返回字典,terms["loss"]
是默认的总损失。AdamW
,学习率 1e-4
仅作参考。实际需结合数据规模和 batch_size 等进行调参。采样与解码
diffusion_model.sample_latents(...)
在潜空间采样出潜向量,再通过 VoxelDecoder
解码为 [n_block_types, 64, 64, 64]
的 logits,并取 argmax
作为预测的方块 ID。其他注意事项
text_encoder.parameters()
也加入优化器参数列表,并在训练循环中把 with torch.no_grad():
去掉,确保梯度正常回传。batch_size
、缩减网络规模、或降低体素分辨率。torch.nn.parallel.DistributedDataParallel
。CLIPTokenizer
太大,或你的文本描述过长,需要注意截断、分词长度等问题。以上即为一个「相对完整」的微调示例代码框架,实际项目中通常还会加入更多的日志管理、指标计算、模型评估和可视化等模块。请根据自己的实际需求进行精细化改动和优化。祝你实验顺利!