In-depth analysis of the Shap-E code and fine-tuning strategy. Shap-E, released by OpenAI, can take a text description and...
Created: 18 January 2025
Answered using OpenAI 01 Pro by OpenAI01
Organizing the code
I’m gathering steps to create a comprehensive fine-tuning script. This involves data loading, model setup, training, embedding, encoding, feeding into the diffusion model, backpropagation, and decoding.
Incorporating necessary imports
I’m digging into importing essential libraries like PyTorch, transformers, and diffusion models. Considering steps for data reading, processing .jsonl files, and integrating them seamlessly into the training process.
Crafting a streamlined process
I’m breaking down the process into data preparation, dataset class, training script, and sampling code. Addressing pitfalls like undefined VOXEL_SIZE and EMPTY_BLOCK_ID with clear definitions.
Streamlining data handling
I’m detailing the steps to read, tokenize descriptions, and process voxel data. Saving outputs as tensors ensures a clear, organized approach, making the data preparation step more efficient.
Converting voxel data
I’m gathering info on translating voxel data into a 64^3 shape, which the VoxelEncoder needs. This includes setting EMPTY_BLOCK_ID and crafting a function for this transformation.
Wrapping up the task
I’m finalizing the training script, focusing on defining voxel transformation functions, setting constants, and ensuring tokenization aligns with the text encoder. Everything is being wrapped up.
Streamlining token handling
I’m examining whether to store raw text or token IDs in "descriptions.pt". This hinges on the text encoder's approach, aiming to avoid double tokenization and ensure consistent token handling.
Clearing the process
I’m pulling together a plan to store raw text in "descriptions.pt" and reformulating voxel data into a dense 64^3 shape for streamlined handling during re-tokenization by the text encoder.
Creating voxel representation
I’m piecing together a function to convert voxel data into a 64^3 volume, using a dictionary for block IDs and a sparse representation adapted by the user if necessary.
Shaping the script
I'm crafting a single, integrated code snippet that covers data preparation, dataset handling, training loop, and sampling approach, with a focus on managing device-specific terminology and dependencies.
Thought for 2m 5s
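Below is a relatively complete example fine-tuning script that walks through the full pipeline, from raw JSONL data to the final generated voxels. The code is mainly for teaching and reference; in a real project you still need to adapt it to your own data format, hyperparameters, and hardware.

Before using the script, make sure the dependencies are installed:

    pip install shap-e
    pip install transformers
    pip install torch torchvision

- shap-e: the Shap-E library from OpenAI (the upstream repository documents installing from a clone with pip install -e . ; use that if pip install shap-e does not resolve to OpenAI's package)
- transformers: the Hugging Face Transformers toolkit
- torch / torchvision: PyTorch

Suppose you have a JSONL file containing Minecraft voxel descriptions and coordinates (for example training_data/10001-1.jsonl), in the following format: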
    {"input": "a red wool block", "output": "red_wool 10 10 10"}
    {"input": "a house built of stone", "output": "stone 0 0 0\nstone 1 0 0\n..."}
    ...
Where:

- "input" is the natural-language description
- "output" is one or more lines of voxel coordinates, each of the form: block_name x y z
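As a minimal sketch of the record format described above, the snippet below parses one JSONL line into a description and a list of (block_name, x, y, z) tuples. The helper name parse_record is hypothetical and not part of the original script; malformed voxel rows are simply skipped.

```python
import json


def parse_record(line):
    """Parse one JSONL line into (description, [(block_name, x, y, z), ...])."""
    record = json.loads(line)
    description = record.get("input", "")
    voxels = []
    for row in record.get("output", "").splitlines():
        parts = row.split()
        if len(parts) != 4:
            continue  # skip blank or malformed rows
        name, x, y, z = parts
        try:
            voxels.append((name, int(x), int(y), int(z)))
        except ValueError:
            continue  # skip rows whose coordinates are not integers
    return description, voxels


# The doubled backslash keeps "\n" as a JSON escape inside the Python literal.
example = '{"input": "a stone step", "output": "stone 0 0 0\\nstone 1 0 0\\nbad line"}'
desc, voxels = parse_record(example)
print(desc)
print(voxels)
```

The third row ("bad line") has only two fields, so it is dropped rather than raising.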
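You also need a blockids.txt file that defines the mapping between numeric IDs and block names, for example:

    block_ids = { 0: "air", 1: "stone", 2: "dirt", 3: "red_wool", ... }

Note: this is only an example; define the mapping yourself according to the Minecraft block types you actually use.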
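Because blockids.txt holds a plain dictionary literal, it can be read without executing arbitrary code. The sketch below is one way to do that, assuming the file contains a single assignment of the form block_ids = {...} with literal keys and values; load_block_ids and invert are hypothetical helper names. It also builds the reverse name -> ID lookup that preprocessing needs.

```python
import ast


def load_block_ids(source):
    """Extract the dict literal assigned to `block_ids` from Python source text."""
    tree = ast.parse(source)
    for node in tree.body:
        if isinstance(node, ast.Assign) and any(
            isinstance(target, ast.Name) and target.id == "block_ids"
            for target in node.targets
        ):
            # literal_eval only accepts literals, so no code is executed.
            return ast.literal_eval(node.value)
    raise ValueError("no `block_ids = {...}` assignment found")


def invert(block_ids):
    """Build the name -> id lookup; fail loudly on duplicate block names."""
    name_to_id = {}
    for block_id, name in block_ids.items():
        if name in name_to_id:
            raise ValueError(f"duplicate block name: {name}")
        name_to_id[name] = block_id
    return name_to_id


source = 'block_ids = {0: "air", 1: "stone", 2: "dirt", 3: "red_wool"}'
block_ids = load_block_ids(source)
name_to_id = invert(block_ids)
print(name_to_id["red_wool"])
```

Inverting the dictionary once also avoids scanning every entry for each voxel row during preprocessing.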
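We then convert the JSONL data into two PyTorch files:

- descriptions.pt: the list of raw text strings (or batches of them)
- voxels.pt: the sparse voxel representation (or a dense 3D tensor built from it before saving)

The example script below converts the JSONL into usable training data (sparse voxel lists) and saves it. You can save it separately as prepare_data.py or run it directly inside your main script.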
    import os
    import json
    import torch

    # ========== 1. Paths ==========
    training_data_path = 'training_data/10001-1.jsonl'  # your JSONL file
    block_ids_path = 'blockids.txt'                     # Python snippet defining block_ids
    output_folder = 'processed_output'                  # output folder
    os.makedirs(output_folder, exist_ok=True)

    # ========== 2. Load block_ids ==========
    # blockids.txt is executed as Python code, so only use a file you trust.
    namespace = {}
    with open(block_ids_path, 'r', encoding='utf-8') as f:
        exec(f.read(), namespace)  # runs: block_ids = { ... }
    block_ids = namespace['block_ids']
    # Invert once (name -> id) instead of scanning the dict for every voxel.
    name_to_id = {name: block_id for block_id, name in block_ids.items()}

    # ========== 3. Containers ==========
    # Descriptions stay as raw strings; the text encoder tokenizes them during training.
    description_list = []  # raw description strings
    voxel_list = []        # per sample: [[block_id, x, y, z], ...]

    # ========== 4. Read the JSONL file ==========
    with open(training_data_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # skip blank lines
            data = json.loads(line)
            description = data.get('input', '')   # natural-language description
            voxel_data = data.get('output', '')   # voxel coordinates

            description_list.append(description)

            # Each voxel row has the form: block_name x y z
            voxel_ids = []
            for voxel_line in voxel_data.strip().split('\n'):
                parts = voxel_line.strip().split()
                if len(parts) != 4:
                    print(f"[warning] invalid voxel line: {voxel_line}")
                    continue
                block_name, x, y, z = parts
                block_id = name_to_id.get(block_name)
                if block_id is None:
                    print(f"[warning] no mapping for block name '{block_name}', skipping.")
                    continue
                voxel_ids.append([int(block_id), int(x), int(y), int(z)])

            # voxel_ids is an (N, 4) list
            voxel_list.append(voxel_ids)

    # ========== 5. Save ==========
    torch.save(description_list, os.path.join(output_folder, 'descriptions.pt'))
    torch.save(voxel_list, os.path.join(output_folder, 'voxels.pt'))
    print("Preprocessing finished: wrote descriptions.pt and voxels.pt.")
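Here the descriptions are stored as a Python list (each element is a string), and the voxel data is stored as a Python list in which each element is the sparse (N, 4) form of block_id plus coordinates.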
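Below are simplified example model definitions, including:

- MinecraftCLIPTextEncoder
- VoxelEncoder (encodes voxel data into a latent)
- VoxelDecoder (decodes the latent back into a voxel distribution)
- diffusion_from_config from shap_e.diffusion.gaussian_diffusion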
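When training the VoxelEncoder, if a fixed-resolution 3D grid (for example 64x64x64) is needed, the sparse (N, 4) representation has to be converted into a dense_tensor[64, 64, 64] in which each voxel stores its block_id. EMPTY_BLOCK_ID = 0 denotes "air" or "no block". Example: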
    import torch

    def voxel_list_to_dense(voxel_ids, size=64, empty_id=0):
        """
        Convert the sparse voxel form (a list of [block_id, x, y, z]) into a
        dense [size, size, size] tensor. Out-of-range coordinates are skipped.
        """
        volume = torch.full((size, size, size), empty_id, dtype=torch.long)
        for block_id, vx, vy, vz in voxel_ids:
            if 0 <= vx < size and 0 <= vy < size and 0 <= vz < size:
                volume[vx, vy, vz] = block_id
        return volume
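Before densifying, it can help to confirm that the coordinates actually fit the chosen grid, since the conversion above silently drops anything out of range. The following pure-Python sketch counts in-range and out-of-range voxels for one sample; the function name count_out_of_range and the toy data are illustrative assumptions.

```python
def count_out_of_range(voxel_ids, size=64):
    """Return (kept, dropped) counts for rows of [block_id, x, y, z]."""
    kept = dropped = 0
    for _block_id, x, y, z in voxel_ids:
        if all(0 <= c < size for c in (x, y, z)):
            kept += 1
        else:
            dropped += 1
    return kept, dropped


sample = [[1, 0, 0, 0], [1, 63, 63, 63], [3, 64, 0, 0], [2, -1, 5, 5]]
kept, dropped = count_out_of_range(sample)
print(f"kept={kept} dropped={dropped}")
```

A large dropped count suggests the builds need to be shifted or rescaled before training rather than clipped.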
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    from torch.utils.data import DataLoader, Dataset
    from transformers import CLIPTokenizer, CLIPTextModel
    from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
    from shap_e.models.download import load_config

    # ========== Hyperparameters ==========
    VOXEL_SIZE = 64        # voxel resolution
    EMPTY_BLOCK_ID = 0     # air / empty block
    NUM_BLOCK_TYPES = 600  # assumes 600 block types; set this from your block_ids

    # ========== 1. Text encoder (CLIP) ==========
    class MinecraftCLIPTextEncoder(nn.Module):
        """
        Uses the CLIP text model as a backbone, followed by a linear layer that
        maps the output to a custom dimension. If you only want CLIP's
        pooler_output, drop the proj layer or use it only for the diffusion condition.
        """
        def __init__(self, clip_model_name="openai/clip-vit-large-patch14", out_dim=512):
            super().__init__()
            self.clip = CLIPTextModel.from_pretrained(clip_model_name)
            self.tokenizer = CLIPTokenizer.from_pretrained(clip_model_name)
            self.proj = nn.Linear(self.clip.config.hidden_size, out_dim)

        def forward(self, text_list):
            """text_list: list of raw sentences (strings)."""
            inputs = self.tokenizer(text_list, return_tensors="pt", padding=True, truncation=True)
            inputs = {k: v.to(self.clip.device) for k, v in inputs.items()}
            outputs = self.clip(**inputs)
            # pooler_output is [batch_size, hidden_size]; pooling last_hidden_state is an alternative
            pooled = outputs.pooler_output
            return self.proj(pooled)  # [batch_size, out_dim]

    # ========== 2. Voxel encoder ==========
    class VoxelEncoder(nn.Module):
        """
        Voxel grid (batch, 1, D, H, W) -> (batch, d_latent).
        Block IDs are categorical; feeding them as a single float channel is a
        simplification. Adjust the architecture and channel counts as needed.
        """
        def __init__(self, d_latent=512):
            super().__init__()
            self.conv1 = nn.Conv3d(1, 32, kernel_size=4, stride=2, padding=1)
            self.conv2 = nn.Conv3d(32, 64, kernel_size=4, stride=2, padding=1)
            self.conv3 = nn.Conv3d(64, 128, kernel_size=4, stride=2, padding=1)
            self.fc = nn.Linear(128 * 8 * 8 * 8, d_latent)

        def forward(self, x):
            # x: [batch_size, 1, 64, 64, 64]
            x = F.relu(self.conv1(x))
            x = F.relu(self.conv2(x))
            x = F.relu(self.conv3(x))
            x = x.view(x.shape[0], -1)
            return self.fc(x)

    # ========== 3. Voxel decoder ==========
    class VoxelDecoder(nn.Module):
        """
        Latent (batch, d_latent) -> logits (batch, NUM_BLOCK_TYPES, D, H, W).
        Take argmax (or softmax) over dim 1 to get each voxel's block_id.
        """
        def __init__(self, d_latent=512, output_size=64, n_channels=NUM_BLOCK_TYPES):
            super().__init__()
            self.fc = nn.Linear(d_latent, 128 * 8 * 8 * 8)
            self.deconv1 = nn.ConvTranspose3d(128, 64, kernel_size=4, stride=2, padding=1)
            self.deconv2 = nn.ConvTranspose3d(64, 32, kernel_size=4, stride=2, padding=1)
            self.deconv3 = nn.ConvTranspose3d(32, n_channels, kernel_size=4, stride=2, padding=1)
            self.output_size = output_size

        def forward(self, x):
            x = self.fc(x)                        # (batch, 128*8*8*8)
            x = x.view(x.shape[0], 128, 8, 8, 8)  # (batch, 128, 8, 8, 8)
            x = F.relu(self.deconv1(x))           # -> (batch, 64, 16, 16, 16)
            x = F.relu(self.deconv2(x))           # -> (batch, 32, 32, 32, 32)
            x = self.deconv3(x)                   # -> (batch, n_channels, 64, 64, 64)
            return x  # logits
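The 128 * 8 * 8 * 8 input size of the encoder's linear layer follows from the standard convolution output-size formula, floor((n + 2p - k) / s) + 1, applied three times; the transposed convolutions in the decoder invert it with (n - 1) * s - 2p + k. A small stdlib-only check, with hypothetical helper names, makes it easy to recompute these sizes if you change VOXEL_SIZE or the kernel settings:

```python
def conv_out(n, kernel=4, stride=2, padding=1):
    """Output length of a strided convolution along one axis (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1


def deconv_out(n, kernel=4, stride=2, padding=1):
    """Output length of a transposed convolution (output_padding 0, dilation 1)."""
    return (n - 1) * stride - 2 * padding + kernel


down = [64]
for _ in range(3):          # conv1, conv2, conv3
    down.append(conv_out(down[-1]))

up = [down[-1]]
for _ in range(3):          # deconv1, deconv2, deconv3
    up.append(deconv_out(up[-1]))

flat = 128 * down[-1] ** 3  # features entering the encoder's fc layer
print(down)  # [64, 32, 16, 8]
print(up)    # [8, 16, 32, 64]
print(flat)  # 65536
```

If the grid size is not divisible by 8, the down and up paths stop being mirror images, and the hard-coded 8 in both fc layers has to change accordingly.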
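The training flow is:

- load descriptions.pt and voxels.pt
- convert the sparse (block_id, x, y, z) representation into a dense 64x64x64 tensor
- pass it through the VoxelEncoder and train the diffusion model against the text embeddings

Below is an example training loop; you can put it in the same .py file or split it across several files.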
    ################################
    # Load the data and build the DataLoader
    ################################
    class MinecraftDataset(Dataset):
        """
        Pairs descriptions.pt with voxels.pt and performs the sparse -> dense
        conversion in __getitem__.
        """
        def __init__(self, descriptions_path, voxels_path, voxel_size=64):
            super().__init__()
            self.descriptions = torch.load(descriptions_path)  # list of strings
            self.voxels_sparse = torch.load(voxels_path)       # list of (N, 4) lists
            self.voxel_size = voxel_size
            assert len(self.descriptions) == len(self.voxels_sparse), \
                "Description and voxel counts differ; check the data."

        def __len__(self):
            return len(self.descriptions)

        def voxel_list_to_dense(self, voxel_ids):
            size = self.voxel_size
            volume = torch.full((size, size, size), EMPTY_BLOCK_ID, dtype=torch.long)
            for block_id, x, y, z in voxel_ids:
                if 0 <= x < size and 0 <= y < size and 0 <= z < size:
                    volume[x, y, z] = block_id
            return volume

        def __getitem__(self, idx):
            desc = self.descriptions[idx]        # text description (string)
            voxel_ids = self.voxels_sparse[idx]  # sparse voxel list
            return {
                "text": desc,
                "voxels": self.voxel_list_to_dense(voxel_ids),  # shape [64, 64, 64]
            }

    ################################
    # Main training script
    ################################
    def train_model():
        import os
        from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
        from shap_e.diffusion.sample import sample_latents
        from shap_e.models.configs import model_from_config
        from shap_e.models.download import load_config

        # ========== Hyperparameters & paths ==========
        output_folder = 'processed_output'
        descriptions_path = os.path.join(output_folder, 'descriptions.pt')
        voxels_path = os.path.join(output_folder, 'voxels.pt')
        batch_size = 2
        num_epochs = 10

        # ========== 1. Dataset ==========
        dataset = MinecraftDataset(descriptions_path, voxels_path, voxel_size=VOXEL_SIZE)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print("Using device:", device)

        # ========== 2. Models ==========
        text_encoder = MinecraftCLIPTextEncoder(out_dim=512).to(device)
        voxel_encoder = VoxelEncoder(d_latent=512).to(device)
        voxel_decoder = VoxelDecoder(d_latent=512).to(device)

        # ========== 3. Diffusion helper and denoiser ==========
        # diffusion_from_config() returns a GaussianDiffusion helper (noise schedule
        # and losses); it is not an nn.Module and has no parameters of its own.
        diffusion = diffusion_from_config(load_config("diffusion"))
        # Denoiser: shap-e's "text300M" config shrunk to one 512-d latent token.
        # Assumptions to verify against your shap-e version: the config keys set
        # below, and that the resulting model accepts a `cond` keyword (the stock
        # text300M transformer is conditioned through `texts=` instead).
        config = load_config("text300M")
        config["inner"]["input_channels"] = 512
        config["inner"]["n_ctx"] = 1
        config["inner"]["output_channels"] = 512 * 2
        denoiser = model_from_config(config, device=device)

        # ========== 4. Optimizer ==========
        params = list(denoiser.parameters())
        params += list(voxel_encoder.parameters())
        params += list(voxel_decoder.parameters())
        # To train the text encoder as well: params += list(text_encoder.parameters())
        optimizer = optim.AdamW(params, lr=1e-4)

        # ========== 5. Training loop ==========
        for epoch in range(num_epochs):
            for step, batch in enumerate(dataloader):
                optimizer.zero_grad()

                text_list = batch["text"]                    # list of strings
                voxels_dense = batch["voxels"].to(device)    # [B, 64, 64, 64], long
                voxels_input = voxels_dense.unsqueeze(1).float()  # [B, 1, 64, 64, 64]
                batch_size_now = voxels_input.shape[0]

                # Random timestep t per sample
                t = torch.randint(0, diffusion.num_timesteps, (batch_size_now,), device=device)

                encoded_voxels = voxel_encoder(voxels_input)  # [B, 512]
                text_embeddings = text_encoder(text_list)     # [B, 512]
                # Add a token axis in case the denoiser expects (B, 1, hidden_dim)
                text_embeddings = text_embeddings.unsqueeze(1)  # [B, 1, 512]

                # Diffusion loss on the latent (detached so it does not reshape the autoencoder)
                terms = diffusion.training_losses(
                    model=denoiser,
                    x_start=encoded_voxels.detach(),
                    t=t,
                    model_kwargs={"cond": text_embeddings},
                )
                diffusion_loss = terms["loss"].mean()

                # Reconstruction loss: the diffusion loss alone never reaches the
                # decoder, so train the autoencoder with per-voxel cross-entropy.
                recon_loss = F.cross_entropy(voxel_decoder(encoded_voxels), voxels_dense)

                loss = diffusion_loss + recon_loss
                loss.backward()
                optimizer.step()

                if step % 10 == 0:
                    print(f"Epoch {epoch}, Step {step}, Loss: {loss.item():.4f}")

            # ========== After each epoch: save checkpoints and sample ==========
            torch.save(denoiser.state_dict(), os.path.join(output_folder, f'diffusion_model_epoch_{epoch}.pt'))
            torch.save(voxel_encoder.state_dict(), os.path.join(output_folder, f'voxel_encoder_epoch_{epoch}.pt'))
            torch.save(voxel_decoder.state_dict(), os.path.join(output_folder, f'voxel_decoder_epoch_{epoch}.pt'))
            torch.save(text_encoder.state_dict(), os.path.join(output_folder, f'text_encoder_epoch_{epoch}.pt'))

            sample_texts = ["a redstone torch"]  # the prompt you want to test
            with torch.no_grad():
                sample_text_embeddings = text_encoder(sample_texts).unsqueeze(1)  # [1, 1, 512]

                # Sample a latent vector with the diffusion model
                sampled_latents = sample_latents(
                    batch_size=1,
                    model=denoiser,
                    diffusion=diffusion,
                    model_kwargs={"cond": sample_text_embeddings},
                    guidance_scale=1.0,  # hyperparameter worth tuning
                    clip_denoised=False,
                    use_fp16=False,
                    use_karras=True,
                    karras_steps=64,
                    sigma_min=1e-3,      # values used in the shap-e sample notebooks
                    sigma_max=160,
                    s_churn=0,
                    device=device,
                    progress=True,
                )

                # Decode to voxels
                sampled_logits = voxel_decoder(sampled_latents)           # [1, 600, 64, 64, 64]
                sampled_voxels = torch.argmax(sampled_logits, dim=1)[0].cpu()  # [64, 64, 64]

            # Write every non-empty voxel as "block_id x y z"
            coords = torch.nonzero(sampled_voxels != EMPTY_BLOCK_ID).tolist()
            lines = [f"{int(sampled_voxels[x, y, z])} {x} {y} {z}" for x, y, z in coords]
            sample_voxel_txt_path = os.path.join(output_folder, f'sample_epoch_{epoch}.txt')
            with open(sample_voxel_txt_path, 'w', encoding='utf-8') as txt_file:
                txt_file.write('\n'.join(lines))
            print(f"[sample] Epoch {epoch} - voxel data saved to: {sample_voxel_txt_path}")

        print("Training finished.")

    if __name__ == "__main__":
        train_model()
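For intuition about what the training_losses call in the loop is doing, the sketch below shows the generic DDPM forward-noising step in plain Python: pick a timestep t, mix the clean latent x_0 with Gaussian noise as x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps, and train the network to recover eps (or x_0) from x_t. The linear beta schedule and the 1000 steps are illustrative assumptions, not Shap-E's actual configuration.

```python
import math
import random


def linear_alpha_bar(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative product of (1 - beta_t) for a linear beta schedule."""
    alpha_bar, running = [], 1.0
    for i in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * i / (num_steps - 1)
        running *= 1.0 - beta
        alpha_bar.append(running)
    return alpha_bar


def q_sample(x0, t, alpha_bar, rng):
    """Return (x_t, eps) with x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps."""
    a = alpha_bar[t]
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    xt = [math.sqrt(a) * x + math.sqrt(1.0 - a) * e for x, e in zip(x0, eps)]
    return xt, eps


rng = random.Random(0)
alpha_bar = linear_alpha_bar()
x0 = [1.0, -0.5, 0.25, 0.0]        # stand-in for an encoded voxel latent
t = rng.randrange(len(alpha_bar))  # random timestep, as in the training loop
xt, eps = q_sample(x0, t, alpha_bar, rng)
print("t =", t)
print("signal weight:", math.sqrt(alpha_bar[t]))
print("noise weight:", math.sqrt(1.0 - alpha_bar[t]))
```

At small t the latent is almost untouched, while near the last step it is almost pure noise; sampling runs this process in reverse, starting from noise.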
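After each epoch, the training script samples from a piece of example text (such as "a redstone torch") and then:

- runs diffusion sampling to obtain sampled_latents
- passes them through voxel_decoder to get the voxel distribution logits
- takes the argmax as the final block ID
- writes the coordinates of every voxel that is not EMPTY_BLOCK_ID to a .txt file

If you need to visualize the result or import it into Minecraft, you can convert the generated .txt file further, for example into .schematic or another format that editors recognize.

For convenience, data preprocessing and training are combined into a single script here (you can also split it into two files). Adjust the paths, hyperparameters, block ID mapping, and so on to your actual needs.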
    #!/usr/bin/env python
    # -*- coding: utf-8 -*-
    import os
    import json

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    from torch.utils.data import DataLoader, Dataset
    from transformers import CLIPTokenizer, CLIPTextModel

    from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
    from shap_e.diffusion.sample import sample_latents
    from shap_e.models.configs import model_from_config
    from shap_e.models.download import load_config

    ################################
    # 0. Hyperparameters and globals
    ################################
    VOXEL_SIZE = 64        # voxel grid size
    EMPTY_BLOCK_ID = 0     # empty block ID
    NUM_BLOCK_TYPES = 600  # set this from the real number of entries in block_ids

    training_data_path = 'training_data/10001-1.jsonl'  # JSONL data
    block_ids_path = 'blockids.txt'                     # Python snippet defining block_ids
    output_folder = 'processed_output'
    os.makedirs(output_folder, exist_ok=True)

    ################################
    # 1. Preprocess the data (JSONL -> .pt)
    ################################
    def prepare_data():
        # exec() cannot rebind a local variable inside a function, so run the
        # file in its own namespace. Only use a blockids.txt you trust.
        namespace = {}
        with open(block_ids_path, 'r', encoding='utf-8') as f:
            exec(f.read(), namespace)  # e.g. block_ids = {0: "air", 1: "stone", ...}
        name_to_id = {name: block_id for block_id, name in namespace['block_ids'].items()}

        # Descriptions are saved as raw text; no tokenization up front.
        description_list = []
        voxel_list = []

        with open(training_data_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                data = json.loads(line)
                desc = data.get('input', '')
                voxel_data = data.get('output', '')

                # 1) Save the description
                description_list.append(desc)

                # 2) Parse the voxels
                sparse_voxels = []
                for row in voxel_data.strip().split('\n'):
                    parts = row.strip().split()
                    if len(parts) != 4:
                        print(f"[warning] invalid voxel line: {row}")
                        continue
                    block_name, x, y, z = parts
                    b_id = name_to_id.get(block_name)
                    if b_id is None:
                        print(f"[warning] no mapping for block_name='{block_name}', skipped.")
                        continue
                    sparse_voxels.append([int(b_id), int(x), int(y), int(z)])
                voxel_list.append(sparse_voxels)

        torch.save(description_list, os.path.join(output_folder, 'descriptions.pt'))
        torch.save(voxel_list, os.path.join(output_folder, 'voxels.pt'))
        print("Preprocessing finished.")

    ################################
    # 2. Dataset and models
    ################################
    class MinecraftDataset(Dataset):
        def __init__(self, descriptions_path, voxels_path, voxel_size=64):
            super().__init__()
            self.descriptions = torch.load(descriptions_path)  # list of str
            self.voxels_sparse = torch.load(voxels_path)       # list of (N, 4) lists
            self.voxel_size = voxel_size
            assert len(self.descriptions) == len(self.voxels_sparse), \
                "Description and voxel counts differ."

        def __len__(self):
            return len(self.descriptions)

        def voxel_list_to_dense(self, voxel_ids):
            size = self.voxel_size
            volume = torch.full((size, size, size), EMPTY_BLOCK_ID, dtype=torch.long)
            for block_id, x, y, z in voxel_ids:
                if 0 <= x < size and 0 <= y < size and 0 <= z < size:
                    volume[x, y, z] = block_id
            return volume

        def __getitem__(self, idx):
            return {
                "text": self.descriptions[idx],
                "voxels": self.voxel_list_to_dense(self.voxels_sparse[idx]),
            }

    class MinecraftCLIPTextEncoder(nn.Module):
        def __init__(self, clip_model_name="openai/clip-vit-large-patch14", out_dim=512):
            super().__init__()
            self.clip = CLIPTextModel.from_pretrained(clip_model_name)
            self.tokenizer = CLIPTokenizer.from_pretrained(clip_model_name)
            self.proj = nn.Linear(self.clip.config.hidden_size, out_dim)

        def forward(self, text_list):
            inputs = self.tokenizer(text_list, return_tensors="pt", padding=True, truncation=True)
            inputs = {k: v.to(self.clip.device) for k, v in inputs.items()}
            outputs = self.clip(**inputs)
            return self.proj(outputs.pooler_output)

    class VoxelEncoder(nn.Module):
        def __init__(self, d_latent=512):
            super().__init__()
            self.conv1 = nn.Conv3d(1, 32, kernel_size=4, stride=2, padding=1)
            self.conv2 = nn.Conv3d(32, 64, kernel_size=4, stride=2, padding=1)
            self.conv3 = nn.Conv3d(64, 128, kernel_size=4, stride=2, padding=1)
            # After three stride-2 convolutions the grid shrinks to 8x8x8 (64/2/2/2 = 8)
            self.fc = nn.Linear(128 * 8 * 8 * 8, d_latent)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            x = F.relu(self.conv2(x))
            x = F.relu(self.conv3(x))
            x = x.view(x.size(0), -1)
            return self.fc(x)

    class VoxelDecoder(nn.Module):
        def __init__(self, d_latent=512, output_size=64, n_channels=NUM_BLOCK_TYPES):
            super().__init__()
            self.fc = nn.Linear(d_latent, 128 * 8 * 8 * 8)
            self.deconv1 = nn.ConvTranspose3d(128, 64, kernel_size=4, stride=2, padding=1)
            self.deconv2 = nn.ConvTranspose3d(64, 32, kernel_size=4, stride=2, padding=1)
            self.deconv3 = nn.ConvTranspose3d(32, n_channels, kernel_size=4, stride=2, padding=1)

        def forward(self, x):
            x = self.fc(x)
            x = x.view(x.shape[0], 128, 8, 8, 8)
            x = F.relu(self.deconv1(x))
            x = F.relu(self.deconv2(x))
            return self.deconv3(x)  # (batch, n_channels, 64, 64, 64)

    ################################
    # 3. Training entry point
    ################################
    def train_model():
        # 1) DataLoader
        dataset = MinecraftDataset(
            os.path.join(output_folder, 'descriptions.pt'),
            os.path.join(output_folder, 'voxels.pt'),
            voxel_size=VOXEL_SIZE,
        )
        dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

        # 2) Device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print("[Info] Using device:", device)

        # 3) Models
        text_encoder = MinecraftCLIPTextEncoder(out_dim=512).to(device)
        voxel_encoder = VoxelEncoder(d_latent=512).to(device)
        voxel_decoder = VoxelDecoder(d_latent=512).to(device)

        # 4) Diffusion helper and denoiser
        # diffusion_from_config() returns a GaussianDiffusion helper, not an nn.Module.
        diffusion = diffusion_from_config(load_config("diffusion"))
        # Assumptions to verify against your shap-e version: the config keys set
        # below, and that the resulting model accepts a `cond` keyword (the stock
        # text300M transformer is conditioned through `texts=` instead).
        config = load_config("text300M")
        config["inner"]["input_channels"] = 512
        config["inner"]["n_ctx"] = 1
        config["inner"]["output_channels"] = 512 * 2
        denoiser = model_from_config(config, device=device)

        # 5) Optimizer
        params = list(denoiser.parameters()) \
            + list(voxel_encoder.parameters()) \
            + list(voxel_decoder.parameters())
        # To train the text encoder as well: params += list(text_encoder.parameters())
        optimizer = optim.AdamW(params, lr=1e-4)

        num_epochs = 5
        for epoch in range(num_epochs):
            for step, batch in enumerate(dataloader):
                optimizer.zero_grad()

                text_list = batch["text"]
                voxel_ids = batch["voxels"].to(device)       # (B, 64, 64, 64), long
                voxels_input = voxel_ids.unsqueeze(1).float()  # (B, 1, 64, 64, 64)
                B = voxels_input.size(0)
                t = torch.randint(0, diffusion.num_timesteps, (B,), device=device)

                # Encode
                encoded_voxels = voxel_encoder(voxels_input)    # (B, 512)
                text_emb = text_encoder(text_list).unsqueeze(1)  # (B, 1, 512)

                # Diffusion loss on the detached latent
                terms = diffusion.training_losses(
                    model=denoiser,
                    x_start=encoded_voxels.detach(),
                    t=t,
                    model_kwargs={"cond": text_emb},
                )
                # Reconstruction loss so the encoder/decoder pair is actually trained
                recon_loss = F.cross_entropy(voxel_decoder(encoded_voxels), voxel_ids)
                loss = terms["loss"].mean() + recon_loss
                loss.backward()
                optimizer.step()

                if step % 10 == 0:
                    print(f"[Epoch {epoch}/{num_epochs}] Step {step}, Loss={loss.item():.4f}")

            # ========== Save and sample once per epoch ==========
            torch.save(denoiser.state_dict(), os.path.join(output_folder, f'diffusion_model_e{epoch}.pt'))
            torch.save(voxel_encoder.state_dict(), os.path.join(output_folder, f'voxel_encoder_e{epoch}.pt'))
            torch.save(voxel_decoder.state_dict(), os.path.join(output_folder, f'voxel_decoder_e{epoch}.pt'))
            torch.save(text_encoder.state_dict(), os.path.join(output_folder, f'text_encoder_e{epoch}.pt'))

            # Test sampling
            with torch.no_grad():
                sample_emb = text_encoder(["a redstone torch"]).unsqueeze(1)  # (1, 1, 512)
                sampled_latents = sample_latents(
                    batch_size=1,
                    model=denoiser,
                    diffusion=diffusion,
                    model_kwargs={"cond": sample_emb},
                    guidance_scale=1.0,
                    clip_denoised=False,
                    use_fp16=False,
                    use_karras=True,
                    karras_steps=64,
                    sigma_min=1e-3,  # values used in the shap-e sample notebooks
                    sigma_max=160,
                    s_churn=0,
                    device=device,
                    progress=True,
                )
                logits = voxel_decoder(sampled_latents)           # (1, NUM_BLOCK_TYPES, 64, 64, 64)
                pred_blocks = torch.argmax(logits, dim=1)[0].cpu()  # (64, 64, 64)

            # Write every non-empty voxel as "block_id x y z"
            coords = torch.nonzero(pred_blocks != EMPTY_BLOCK_ID).tolist()
            lines = [f"{int(pred_blocks[x, y, z])} {x} {y} {z}" for x, y, z in coords]
            save_path = os.path.join(output_folder, f"sample_e{epoch}.txt")
            with open(save_path, "w", encoding="utf-8") as fw:
                fw.write("\n".join(lines))
            print(f"[sample] generated voxels saved to {save_path}")

        print("[Info] Training finished.")

    if __name__ == "__main__":
        # Step 1: preprocess the data
        prepare_data()
        # Step 2: train the model
        train_model()
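Running this script performs the following steps in order:

- reads the JSONL and parses it into descriptions.pt & voxels.pt
- trains the models and saves checkpoints after every epoch
- samples once per epoch and writes the generated voxels to a .txt file

Note: the number of block types (NUM_BLOCK_TYPES), the voxel resolution, the network architecture, and the hyperparameters (learning rate, batch size, and so on) are all example values; adjust and experiment with them according to your data size and hardware.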
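As a starting point for any further conversion, the sample .txt files written by the script (one "block_id x y z" row per non-empty voxel) can be mapped back to block names. The sketch below assumes that row format and an ID -> name dictionary like the block_ids example; ids_to_named_rows is a hypothetical helper, and IDs missing from the mapping are reported rather than guessed.

```python
def ids_to_named_rows(text, block_ids):
    """Convert 'block_id x y z' rows into 'block_name x y z' rows."""
    named, unknown = [], set()
    for row in text.splitlines():
        parts = row.split()
        if len(parts) != 4:
            continue  # skip blank or malformed rows
        block_id, x, y, z = (int(p) for p in parts)
        name = block_ids.get(block_id)
        if name is None:
            unknown.add(block_id)
            continue
        named.append(f"{name} {x} {y} {z}")
    return named, sorted(unknown)


block_ids = {0: "air", 1: "stone", 3: "red_wool"}
sample_txt = "1 0 0 0\n3 10 10 10\n42 1 2 3"
named, unknown = ids_to_named_rows(sample_txt, block_ids)
print(named)
print(unknown)
```

The named rows use the same "block_name x y z" layout as the "output" field of the training data, so generated samples can be compared directly with the originals.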
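With the example script above you can fine-tune (or train) a Shap-E-style "text -> latent -> voxel" generation pipeline. A few key points deserve attention in a real project:

- guidance_scale, use_karras, and similar hyperparameters can all affect the final generation quality and efficiency.

The example above can serve as an end-to-end starting point; from there you can customize or improve each module to fit your needs. Good luck applying Shap-E to Minecraft voxel generation!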