Help me modify this: the original project predicts block positions to infer the next block's position. I don't want that anymore; make it take block positions and predict block IDs,...

Created: May 2, 2025

Answered using O3 by Chat01

Question

Help me modify this. The original project takes block positions and infers the next block's position; I don't want that anymore. Change it so that, given block positions, it predicts block IDs. That is, feed it the block positions of a whole house and have it infer each block's ID, e.g. stone or wood. Understood? Then switch to a DiT architecture and train with PyTorch Lightning. Once modified, send me a single .py file that contains both training and inference. Every training step should print the loss and the running average loss, and after every epoch it should run inference once. You can generate a matchbox-shaped building with no block IDs and have it infer every block's ID; if a few blocks already have IDs at that point, it should predict the IDs of the nearby blocks. You can seed a few blocks with IDs for it to reason from, as a test of its context ability. Send the complete modified code. Let the user set the batch size from the command line, and save the model to the logs directory. After writing the code, remember to tell me how to use it. Here is the project code:
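As an aside before the attached code: a minimal sketch of the input/target encoding this request implies, assuming dense voxel grids and CrossEntropyLoss-style ignore indices. All names and sizes here are hypothetical, not part of the original project.

import torch

NUM_TYPES = 256  # matches Craft3DDataset.NUM_BLOCK_TYPES below
D = 16           # hypothetical grid size

# Occupancy grid: 1 where a block exists; a hollow "matchbox" shell
occupancy = torch.zeros(1, D, D, D)
occupancy[0, 2:10, 2:8, 2:12] = 1
occupancy[0, 3:9, 3:7, 3:11] = 0

# Per-voxel block IDs: -100 = unknown (ignored by CrossEntropyLoss)
known_ids = torch.full((D, D, D), -100)
known_ids[2, 2, 2] = 1  # seed a few context blocks, e.g. stone
known_ids[9, 7, 2] = 5  # e.g. wood

# Model inputs: occupancy plus a one-hot encoding of the seeded IDs
known_onehot = torch.zeros(NUM_TYPES, D, D, D)
xs, ys, zs = (known_ids >= 0).nonzero(as_tuple=True)
known_onehot[known_ids[xs, ys, zs], xs, ys, zs] = 1
inputs = torch.cat([occupancy, known_onehot], dim=0)  # (1 + NUM_TYPES, D, D, D)

# The training target would be the true ID at every occupied voxel, -100 elsewhere.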


import os

# Define the output file

output_file = 'output.txt'

# Define the supported file extensions (as a tuple)

SUPPORTED_EXTENSIONS = ('.py', '.md', '.js',)

with open(output_file, 'w', encoding='utf-8') as outfile:
    for root, dirs, files in os.walk('.'):
        for file in files:
            # Check the file extension against the tuple (the more idiomatic form)
            if file.endswith(SUPPORTED_EXTENSIONS):
                file_path = os.path.join(root, file)
                outfile.write(f'File path: {file_path}\n')

                try:
                    with open(file_path, 'r', encoding='utf-8') as infile:
                        # Add a blank line to separate the path from the content
                        outfile.write('\n' + infile.read() + '\n')
                        outfile.write('\n' + '=' * 80 + '\n\n')
                except UnicodeDecodeError:
                    # Handle binary files that were mistakenly read as text
                    outfile.write(f'Warning: {file_path} may be a binary file, skipping\n\n')
                except Exception as e:
                    outfile.write(f'Error reading file {file_path}: {str(e)}\n\n')

print(f'Merge complete. Results saved to {output_file}.')
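To reproduce the dump that follows, save the script above as, say, merge.py (the file name is hypothetical) in the project root and run:

python merge.py

It walks the working tree and concatenates every .py, .md, and .js file into output.txt, which is what follows.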

================================================================================

File path: .\main.py

#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import random
import warnings
from datetime import datetime
from os import path as osp
from time import time as tic

import numpy as np
import torch
from torch import optim
from torch.utils.data import DataLoader
from voxelcnn.checkpoint import Checkpointer
from voxelcnn.criterions import CrossEntropyLoss
from voxelcnn.datasets import Craft3DDataset
from voxelcnn.evaluators import CCA, MTC, Accuracy
from voxelcnn.models import VoxelCNN
from voxelcnn.summary import Summary
from voxelcnn.utils import Section, collate_batches, setup_logger, to_cuda

def global_setup(args):
    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True
    if not args.cpu_only:
        if not torch.cuda.is_available():
            warnings.warn("CUDA is not available. Fallback to using CPU only")
            args.cpu_only = True
        else:
            torch.backends.cudnn.benchmark = True

def build_data_loaders(args, logger):
    data_loaders = {}
    for subset in ("train", "val", "test"):
        dataset = Craft3DDataset(
            args.data_dir,
            subset,
            max_samples=args.max_samples,
            next_steps=10,
            logger=logger,
        )
        data_loaders[subset] = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=subset == "train",
            num_workers=args.num_workers,
            pin_memory=not args.cpu_only,
        )
    return data_loaders

def build_model(args, logger):
    model = VoxelCNN()
    if not args.cpu_only:
        model.cuda()
    logger.info("Model architecture:\n" + str(model))
    return model

def build_criterion(args):
    criterion = CrossEntropyLoss()
    if not args.cpu_only:
        criterion.cuda()
    return criterion

def build_optimizer(args, model):
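    # Split the parameters into two groups so that biases are exempt from weight decay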
    no_decay = []
    decay = []
    for name, param in model.named_parameters():
        if name.endswith(".bias"):
            no_decay.append(param)
        else:
            decay.append(param)
    params = [{"params": no_decay, "weight_decay": 0}, {"params": decay}]
    return optim.SGD(
        params,
        lr=args.lr,
        weight_decay=args.weight_decay,
        momentum=args.momentum,
        nesterov=True,
    )

def build_scheduler(args, optimizer):
    return optim.lr_scheduler.StepLR(
        optimizer, step_size=args.step_size, gamma=args.gamma
    )

def build_evaluators(args):
    return {
        "acc@1": Accuracy(next_steps=1),
        "acc@5": Accuracy(next_steps=5),
        "acc@10": Accuracy(next_steps=10),
    }

def train(
    args, epoch, data_loader, model, criterion, optimizer, scheduler, evaluators, logger
):
    summary = Summary(logger=logger)
    model.train()
    timestamp = tic()
    for i, (inputs, targets) in enumerate(data_loader):
        times = {"data": tic() - timestamp}
        if not args.cpu_only:
            inputs = to_cuda(inputs)
            targets = to_cuda(targets)
        outputs = model(inputs)
        losses = criterion(outputs, targets)
        with torch.no_grad():
            metrics = {k: float(v(outputs, targets)) for k, v in evaluators.items()}

        optimizer.zero_grad()
        losses["overall_loss"].backward()
        optimizer.step()
        try:
            lr = scheduler.get_last_lr()[0]
        except Exception:
            # For backward compatibility
            lr = scheduler.get_lr()[0]

        times["time"] = tic() - timestamp
        summary.add(times=times, lrs={"lr": lr}, losses=losses, metrics=metrics)
        summary.print_current(
            prefix=f"[{epoch}/{args.num_epochs}][{i + 1}/{len(data_loader)}]"
        )
        timestamp = tic()
    scheduler.step()

@torch.no_grad()
def evaluate(args, epoch, data_loader, model, evaluators, logger):
    summary = Summary(logger=logger)
    model.eval()
    timestamp = tic()
    batch_results = []
    for i, (inputs, targets) in enumerate(data_loader):
        times = {"data": tic() - timestamp}
        if not args.cpu_only:
            inputs = to_cuda(inputs)
            targets = to_cuda(targets)
        outputs = model(inputs)
        batch_results.append(
            {k: v.step(outputs, targets) for k, v in evaluators.items()}
        )

        times["time"] = tic() - timestamp
        summary.add(times=times)
        summary.print_current(
            prefix=f"[{epoch}/{args.num_epochs}][{i + 1}/{len(data_loader)}]"
        )
        timestamp = tic()
    results = collate_batches(batch_results)
    metrics = {k: float(v.stop(results[k])) for k, v in evaluators.items()}
    return metrics

def main(args):
    # Set log file name based on current date and time
    cur_datetime = datetime.now().strftime("%Y%m%d.%H%M%S")
    log_path = osp.join(args.save_dir, f"log.{cur_datetime}.txt")
    logger = setup_logger(save_file=log_path)
    logger.info(f"Save logs to: {log_path}")

    with Section("Global setup", logger=logger):
        global_setup(args)

    with Section("Building data loaders", logger=logger):
        data_loaders = build_data_loaders(args, logger)

    with Section("Building model", logger=logger):
        model = build_model(args, logger)

    with Section("Building criterions, optimizer, scheduler", logger=logger):
        criterion = build_criterion(args)
        optimizer = build_optimizer(args, model)
        scheduler = build_scheduler(args, optimizer)

    with Section("Building evaluators", logger=logger):
        evaluators = build_evaluators(args)

    checkpointer = Checkpointer(args.save_dir)
    last_epoch = 0
    if args.resume is not None:
        with Section(f"Resuming from model: {args.resume}", logger=logger):
            last_epoch = checkpointer.resume(
                args.resume, model=model, optimizer=optimizer, scheduler=scheduler
            )

    for epoch in range(last_epoch + 1, args.num_epochs + 1):
        with Section(f"Training epoch {epoch}", logger=logger):
            train(
                args,
                epoch,
                data_loaders["train"],
                model,
                criterion,
                optimizer,
                scheduler,
                evaluators,
                logger,
            )
        with Section(f"Validating epoch {epoch}", logger=logger):
            # Evaluate on the validation set by the lightweight accuracy metrics
            metrics = evaluate(
                args, epoch, data_loaders["val"], model, evaluators, logger
            )
            # Use acc@10 as the key metric to select best model
            checkpointer.save(model, optimizer, scheduler, epoch, metrics["acc@10"])
            metrics_str = "  ".join(f"{k}: {v:.3f}" for k, v in metrics.items())
            best_mark = "*" if epoch == checkpointer.best_epoch else ""
            logger.info(f"Finish  epoch: {epoch}  {metrics_str} {best_mark}")

    best_epoch = checkpointer.best_epoch
    with Section(f"Final test with best model from epoch: {best_epoch}", logger=logger):
        # Load the best model and evaluate all the metrics on the test set
        checkpointer.load("best", model=model)
        metrics = evaluate(
            args, best_epoch, data_loaders["test"], model, evaluators, logger
        )

        # Additional evaluation metrics. Takes quite long time to evaluate
        dataset = data_loaders["test"].dataset
        params = {
            "local_size": dataset.local_size,
            "global_size": dataset.global_size,
            "history": dataset.history,
        }
        metrics.update(CCA(**params).evaluate(dataset, model))
        metrics.update(MTC(**params).evaluate(dataset, model))

        metrics_str = "  ".join(f"{k}: {v:.3f}" for k, v in metrics.items())
        logger.info(f"Final test from best epoch: {best_epoch}\n{metrics_str}")

if __name__ == "__main__":
    work_dir = osp.dirname(osp.abspath(__file__))
    parser = argparse.ArgumentParser(
        description="Train and evaluate VoxelCNN model on 3D-Craft dataset"
    )
    # Data
    parser.add_argument(
        "--data_dir",
        type=str,
        default=osp.join(work_dir, "data"),
        help="Path to the data directory",
    )
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
    parser.add_argument(
        "--num_workers",
        type=int,
        default=16,
        help="Number of workers for preprocessing",
    )
    parser.add_argument(
        "--max_samples",
        type=int,
        default=None,
        help="When debugging, set this option to limit the number of training samples",
    )
    # Optimizer
    parser.add_argument("--lr", type=float, default=0.1, help="Initial learning rate")
    parser.add_argument(
        "--weight_decay", type=float, default=0.0001, help="Weight decay"
    )
    parser.add_argument("--momentum", type=float, default=0.9, help="Momentum")
    # Scheduler
    parser.add_argument("--step_size", type=int, default=5, help="StepLR step size")
    parser.add_argument("--gamma", type=int, default=0.1, help="StepLR gamma")
    parser.add_argument("--num_epochs", type=int, default=12, help="Total train epochs")
    # Misc
    parser.add_argument(
        "--save_dir",
        type=str,
        default=osp.join(work_dir, "logs"),
        help="Path to a directory to save log file and checkpoints",
    )
    parser.add_argument(
        "--resume",
        type=str,
        default=None,
        help="'latest' | 'best' | '<epoch number>' | '<path to a checkpoint>'. "
        "Default: None, will not resume",
    )
    parser.add_argument("--cpu_only", action="store_true", help="Only using CPU")
    parser.add_argument("--seed", type=int, default=None, help="Random seed")
    main(parser.parse_args())
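For example (a hypothetical invocation based on the flags above), training can be launched with a custom batch size and resumed from the latest checkpoint:

python main.py --batch_size 32 --save_dir ./logs --resume latest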

================================================================================

File path: .\README.md

VoxelCNN

VoxelCNN is an order-aware generative model for building houses in Minecraft. This codebase is a PyTorch implementation of the training and evaluation pipeline for VoxelCNN.
VoxelCNN

VoxelCNN is trained and evaluated with the 3D-Craft dataset.
3D-Craft Dataset

For more details, please refer to the ICCV '19 paper Order-Aware Generative Modeling Using the 3D-Craft Dataset.

Installation

Python version >= 3.7 is required.

0. Clone the repo

git clone https://github.com/facebookresearch/VoxelCNN
cd VoxelCNN

1. (Optional) Create a virtual environment

pip install virtualenv
virtualenv voxelcnn_venv
source voxelcnn_venv/bin/activate

2. Install PyTorch

Please follow the official installation guide to install PyTorch version >= 1.3.

3. Install other dependencies

pip install numpy requests tqdm

4. (Optional) Verify installation by running unit tests

python -m unittest discover -s test -v

Training and Evaluation

The 3D-Craft dataset will be downloaded automatically when launching the training for the first time.

Run the fast training (fewer epochs, slightly worse results):

python main.py --num_epochs 3 --step_size 1 --save_dir /path/to/save/log/and/checkpoints

Example final test results for the fast training:

acc@1: 0.622  acc@5: 0.760  acc@10: 0.788
cca_10%: 13.374  cca_25%: 11.115  cca_50%: 12.546  cca_75%: 12.564  cca_90%: 7.632  cca_avg: 11.446
mtc: 131.411  mtc_normed: 0.241

Run the full training:

python main.py --save_dir /path/to/save/log/and/checkpoints

Example final test results for the full training:

acc@1: 0.640  acc@5: 0.778  acc@10: 0.806
cca_10%: 13.630  cca_25%: 12.223  cca_50%: 13.168  cca_75%: 13.047  cca_90%: 7.571  cca_avg: 11.928
mtc: 121.753  mtc_normed: 0.223

License

VoxelCNN is released under the CC-BY-NC 4.0 license.

Citing

@inproceedings{zchen2019,
  title = {Order-Aware Generative Modeling Using the 3D-Craft Dataset},
  author = {Chen, Zhuoyuan and Guo, Demi and Xiao, Tong and Xie, Saining and Chen, Xinlei and Yu, Haonan and Gray, Jonathan and Srinet, Kavya and Fan, Haoqi and Ma, Jerry and Qi, Charles R and Tulsiani, Shubham and Szlam, Arthur and Zitnick, C. Lawrence},
  booktitle = {ICCV},
  year = {2019},
}

================================================================================

File path: .\test\test_criterions.py

#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import unittest

import torch
from voxelcnn.criterions import CrossEntropyLoss

class TestCrossEntropyLoss(unittest.TestCase):
    def test_forward(self):
        targets = {
            "coords": torch.tensor([[0], [13]]),
            "types": torch.tensor([[0], [1]]),
        }

        coords_outputs = torch.zeros((2, 1, 3, 3, 3))
        coords_outputs[0, 0, 0, 0, 0] = 1.0
        coords_outputs[1, 0, 1, 1, 1] = 1.0

        types_outputs = torch.zeros((2, 2, 3, 3, 3))
        types_outputs[0, 0, 0, 0, 0] = 1.0
        types_outputs[1, 1, 1, 1, 1] = 1.0
        outputs = {"coords": coords_outputs, "types": types_outputs}

        criterion = CrossEntropyLoss()
        losses = criterion(outputs, targets)

        p_coords = math.exp(1.0) / (math.exp(1.0) + 26)
        p_types = math.exp(1.0) / (math.exp(1.0) + 1)
        self.assertAlmostEqual(
            float(losses["coords_loss"]), -math.log(p_coords), places=3
        )
        self.assertAlmostEqual(
            float(losses["types_loss"]), -math.log(p_types), places=3
        )

if __name__ == "__main__":
    unittest.main()

================================================================================

File path: .\test\test_datasets.py

#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from voxelcnn.datasets import Craft3DDataset

class TestCraft3DDataset(unittest.TestCase):
    def setUp(self):
        self.annotation = torch.tensor(
            [[0, 0, 0, 0], [3, 3, 3, 3], [1, 1, 1, 1], [2, 2, 2, 2]]
        )

    def test_convert_to_voxels(self):
        voxels = Craft3DDataset._convert_to_voxels(
            self.annotation, size=3, occupancy_only=False
        )
        self.assertEqual(voxels.shape, (256, 3, 3, 3))
        self.assertEqual(voxels[1, 0, 0, 0], 1)
        self.assertEqual(voxels[2, 1, 1, 1], 1)
        self.assertEqual(voxels[3, 2, 2, 2], 1)

        voxels = Craft3DDataset._convert_to_voxels(
            self.annotation, size=5, occupancy_only=True
        )
        self.assertEqual(voxels.shape, (1, 5, 5, 5))
        self.assertEqual(voxels[0, 0, 0, 0], 1)
        self.assertEqual(voxels[0, 1, 1, 1], 1)
        self.assertEqual(voxels[0, 2, 2, 2], 1)
        self.assertEqual(voxels[0, 3, 3, 3], 1)

    def test_prepare_inputs(self):
        inputs = Craft3DDataset.prepare_inputs(
            self.annotation, local_size=3, global_size=5, history=2
        )
        self.assertEqual(set(inputs.keys()), {"local", "global", "center"})
        self.assertTrue(torch.all(inputs["center"] == torch.tensor([2, 2, 2])))
        self.assertEqual(inputs["local"].shape, (512, 3, 3, 3))
        self.assertEqual(inputs["local"][1, 0, 0, 0], 1)
        self.assertEqual(inputs["local"][2, 1, 1, 1], 1)
        self.assertEqual(inputs["local"][3, 2, 2, 2], 1)
        self.assertEqual(inputs["local"][257, 0, 0, 0], 1)
        self.assertEqual(inputs["local"][258, 1, 1, 1], 0)
        self.assertEqual(inputs["local"][259, 2, 2, 2], 1)
        self.assertEqual(inputs["global"].shape, (1, 5, 5, 5))
        self.assertEqual(inputs["global"][0, 0, 0, 0], 1)
        self.assertEqual(inputs["global"][0, 1, 1, 1], 1)
        self.assertEqual(inputs["global"][0, 2, 2, 2], 1)
        self.assertEqual(inputs["global"][0, 3, 3, 3], 1)

    def test_prepare_targets(self):
        targets = Craft3DDataset.prepare_targets(
            self.annotation, next_steps=2, local_size=3
        )
        self.assertEqual(set(targets.keys()), {"coords", "types"})
        self.assertTrue(torch.all(targets["coords"] == torch.tensor([-100, 26])))
        self.assertTrue(torch.all(targets["types"] == torch.tensor([-100, 1])))

if __name__ == "__main__":
    unittest.main()

================================================================================

File path: .\test\test_evaluators.py

#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from voxelcnn.evaluators import Accuracy

class TestAccuracy(unittest.TestCase):
    def test_forward(self):
        coords_outputs = torch.zeros((2, 1, 3, 3, 3))
        coords_outputs[0, 0, 0, 0, 0] = 1.0
        coords_outputs[1, 0, 1, 1, 1] = 1.0

        types_outputs = torch.zeros((2, 2, 3, 3, 3))
        types_outputs[0, 0, 0, 0, 0] = 1.0
        types_outputs[1, 1, 1, 1, 1] = 1.0

        outputs = {"coords": coords_outputs, "types": types_outputs}
        targets = {
            "coords": torch.tensor([[0, 1], [12, 13]]),
            "types": torch.tensor([[0, 0], [1, 1]]),
        }

        acc1 = Accuracy(next_steps=1)(outputs, targets).item()
        self.assertEqual(acc1, 0.5)

        acc2 = Accuracy(next_steps=2)(outputs, targets).item()
        self.assertEqual(acc2, 1.0)

================================================================================

File path: .\test\test_models.py

#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from voxelcnn.models import VoxelCNN

class TestVoxelCNN(unittest.TestCase):
    def test_forward(self):
        model = VoxelCNN()
        inputs = {
            "local": torch.rand(5, 256 * 3, 7, 7, 7),
            "global": torch.rand(5, 1, 21, 21, 21),
            "center": torch.randint(128, size=(5, 3)),
        }
        outputs = model(inputs)
        self.assertEqual(set(outputs.keys()), {"coords", "types", "center"})
        self.assertEqual(outputs["coords"].shape, (5, 1, 7, 7, 7))
        self.assertEqual(outputs["types"].shape, (5, 256, 7, 7, 7))
        self.assertEqual(outputs["center"].shape, (5, 3))

================================================================================

File path: .\test\test_predictor.py

#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from voxelcnn.predictor import Predictor

class TestPredictor(unittest.TestCase):
    def test_decode(self):
        coords_outputs = torch.zeros((2, 1, 3, 3, 3))
        coords_outputs[0, 0, 0, 0, 0] = 1.0
        coords_outputs[1, 0, 1, 1, 1] = 1.0

        types_outputs = torch.zeros((2, 2, 3, 3, 3))
        types_outputs[0, 0, 0, 0, 0] = 1.0
        types_outputs[1, 1, 1, 1, 1] = 1.0

        center = torch.tensor([[3, 3, 3], [10, 11, 12]])
        outputs = {"coords": coords_outputs, "types": types_outputs, "center": center}

        predictions = Predictor.decode(outputs)
        self.assertTrue(
            torch.all(predictions == torch.tensor([[0, 2, 2, 2], [1, 10, 11, 12]]))
        )

================================================================================

File path: .\voxelcnn\checkpoint.py

#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import shutil
from glob import glob
from os import path as osp
from typing import Any, Dict, Optional

import torch
from torch import nn, optim

class Checkpointer(object):
    def __init__(self, root_dir: str):
        """ Save and load checkpoints. Maintain best metrics

        Args:
            root_dir (str): Directory to save the checkpoints
        """
        super().__init__()
        self.root_dir = root_dir
        self.best_metric = -1
        self.best_epoch = None

    def save(
        self,
        model: nn.Module,
        optimizer: optim.Optimizer,
        scheduler: optim.lr_scheduler._LRScheduler,
        epoch: int,
        metric: float,
    ):
        if self.best_metric < metric:
            self.best_metric = metric
            self.best_epoch = epoch
            is_best = True
        else:
            is_best = False

        os.makedirs(self.root_dir, exist_ok=True)
        torch.save(
            {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "epoch": epoch,
                "best_epoch": self.best_epoch,
                "best_metric": self.best_metric,
            },
            osp.join(self.root_dir, f"{epoch:02d}.pth"),
        )

        if is_best:
            shutil.copy(
                osp.join(self.root_dir, f"{epoch:02d}.pth"),
                osp.join(self.root_dir, "best.pth"),
            )

    def load(
        self,
        load_from: str,
        model: Optional[nn.Module] = None,
        optimizer: Optional[optim.Optimizer] = None,
        scheduler: Optional[optim.lr_scheduler._LRScheduler] = None,
    ) -> Dict[str, Any]:
        ckp = torch.load(self._get_path(load_from))
        if model is not None:
            model.load_state_dict(ckp["model"])
        if optimizer is not None:
            optimizer.load_state_dict(ckp["optimizer"])
        if scheduler is not None:
            scheduler.load_state_dict(ckp["scheduler"])
        return ckp

    def resume(
        self,
        resume_from: str,
        model: Optional[nn.Module] = None,
        optimizer: Optional[optim.Optimizer] = None,
        scheduler: Optional[optim.lr_scheduler._LRScheduler] = None,
    ) -> int:
        ckp = self.load(
            resume_from, model=model, optimizer=optimizer, scheduler=scheduler
        )
        self.best_epoch = ckp["best_epoch"]
        self.best_metric = ckp["best_metric"]
        return ckp["epoch"]

    def _get_path(self, load_from: str) -> str:
        if load_from == "best":
            return osp.join(self.root_dir, "best.pth")
        if load_from == "latest":
            return sorted(glob(osp.join(self.root_dir, "[0-9]*.pth")))[-1]
        if load_from.isnumeric():
            return osp.join(self.root_dir, f"{int(load_from):02d}.pth")
        return load_from
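A minimal round-trip sketch of the Checkpointer API above; the tiny model, optimizer, and scheduler are illustrative stand-ins, not part of the project:

from torch import nn, optim
from voxelcnn.checkpoint import Checkpointer

model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5)

ckpt = Checkpointer("logs")
ckpt.save(model, optimizer, scheduler, epoch=1, metric=0.5)  # writes logs/01.pth and, being best so far, logs/best.pth
last_epoch = ckpt.resume("latest", model=model, optimizer=optimizer, scheduler=scheduler)
assert last_epoch == 1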

================================================================================

File path: .\voxelcnn\criterions.py

#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from torch import nn

class CrossEntropyLoss(nn.Module):
    def __init__(self, ignore_index: int = -100):
        super().__init__()
        self.loss = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def forward(
        self, outputs: Dict[str, torch.Tensor], targets: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """ Compute CrossEntropyLoss for coordinates and block types predictions

        Args:
            outputs (dict): A dict of
                {
                    "coords": float tensor of shape (N, 1, D, D, D),
                    "types": float tensor of shape (N, C, D, D, D),
                }
                where N is the batch size, C is the number of block types, and D is the
                local size.

            targets (dict): A dict of
                {
                    "coords": int tensor of shape (N, A), the encoded target coordinates
                    "types": int tensor of shape (N, A), the target block types
                }
                where N is the batch size and A is the number of next groundtruth
                actions. However, we only use the next one action to compute the loss.

        Returns:
            A dict of losses: coords_loss, types_loss, overall_loss, where each is a
            tensor of float scalar
        """
        N, C, D, D, D = outputs["types"].shape
        assert outputs["coords"].shape == (N, 1, D, D, D)

        coords_targets = targets["coords"][:, 0].view(-1)
        types_targets = targets["types"][:, 0].view(-1)

        coords_outputs = outputs["coords"].view(N, -1)

        # Gather the type prediction on ground truth coordinate
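        # (the (N, 1, 1) index is expanded to (N, C, 1), so each of the C type
        # channels reads its score at the target's flat coordinate in the D^3 grid)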
        types_outputs = (
            outputs["types"]
            .view(N, C, D * D * D)
            .gather(dim=2, index=coords_targets.view(N, 1, 1).expand(N, C, 1))
            .view(N, -1)
        )

        coords_loss = self.loss(coords_outputs, coords_targets)
        types_loss = self.loss(types_outputs, types_targets)

        return {
            "coords_loss": coords_loss,
            "types_loss": types_loss,
            "overall_loss": coords_loss + types_loss,
        }

================================================================================

File path: .\voxelcnn\datasets.py

#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import json
import logging
import os
import tarfile
import warnings
from os import path as osp
from typing import Dict, Optional, Tuple

import numpy as np
import requests
import torch
from torch.utils.data import Dataset

class Craft3DDataset(Dataset):
    NUM_BLOCK_TYPES = 256
    URL = "https://craftassist.s3-us-west-2.amazonaws.com/pubr/house_data.tar.gz"

    def __init__(
        self,
        data_dir: str,
        subset: str,
        local_size: int = 7,
        global_size: int = 21,
        history: int = 3,
        next_steps: int = -1,
        max_samples: Optional[int] = None,
        logger: Optional[logging.Logger] = None,
    ):
        """ Download and construct 3D-Craft dataset

        data_dir (str): Directory to save/load the dataset
        subset (str): 'train' | 'val' | 'test'
        local_size (int): Local context size. Default: 7
        global_size (int): Global context size. Default: 21
        history (int): Number of previous steps considered as inputs. Default: 3
        next_steps (int): Number of next steps considered as targets. Default: -1,
            meaning till the end
        max_samples (int, optional): Limit the maximum number of samples. Used for
            faster debugging. Default: None, meaning no limit
        logger (logging.Logger, optional): A logger. Default: None, meaning will print
            to stdout
        """
        super().__init__()
        self.data_dir = data_dir
        self.subset = subset
        self.local_size = local_size
        self.global_size = global_size
        self.history = history
        self.max_local_distance = self.local_size // 2
        self.max_global_distance = self.global_size // 2
        self.next_steps = next_steps
        self.max_samples = max_samples
        self.logger = logger

        if self.subset not in ("train", "val", "test"):
            raise ValueError(f"Unknown subset: {self.subset}")

        if not self._has_raw_data():
            self._download()

        self._load_dataset()
        self._find_valid_items()

        self.print_stats()

    def print_stats(self):
        num_blocks_per_house = [len(x) for x in self._valid_indices.values()]
        ret = "\n"
        ret += f"3D Craft Dataset\n"
        ret += f"================\n"
        ret += f"  data_dir: {self.data_dir}\n"
        ret += f"  subset: {self.subset}\n"
        ret += f"  local_size: {self.local_size}\n"
        ret += f"  global_size: {self.global_size}\n"
        ret += f"  history: {self.history}\n"
        ret += f"  next_steps: {self.next_steps}\n"
        ret += f"  max_samples: {self.max_samples}\n"
        ret += f"  --------------\n"
        ret += f"  num_houses: {len(self._valid_indices)}\n"
        ret += f"  avg_blocks_per_house: {np.mean(num_blocks_per_house):.3f}\n"
        ret += f"  min_blocks_per_house: {min(num_blocks_per_house)}\n"
        ret += f"  max_blocks_per_house: {max(num_blocks_per_house)}\n"
        ret += f"  total_valid_blocks: {len(self._flattened_valid_indices)}\n"
        ret += "\n"
        self._log(ret)

    def __len__(self) -> int:
        """ Get number of valid blocks """
        ret = len(self._flattened_valid_indices)
        if self.max_samples is not None:
            ret = min(ret, self.max_samples)
        return ret

    def __getitem__(
        self, index: int
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """ Get the index-th valid block

        Returns:
            A tuple of inputs and targets, where inputs is a dict of
            {
                "local": float tensor of shape (C * H, D, D, D),
                "global": float tensor of shape (1, G, G, G),
                "center": int tensor of shape (3,), the coordinate of the last block
            }
            where C is the number of block types, H is the history length, D is the
            local size, and G is the global size.

            targets is a dict of
            {
                "coords": int tensor of shape (A,)
                "types": int tensor of shape (A,)
            }
            where A is the number of next steps to be considered as targets.
        """
        house_id, block_id = self._flattened_valid_indices[index]
        annotation = self._all_houses[house_id]
        inputs = Craft3DDataset.prepare_inputs(
            annotation[: block_id + 1],
            local_size=self.local_size,
            global_size=self.global_size,
            history=self.history,
        )
        targets = Craft3DDataset.prepare_targets(
            annotation[block_id:],
            next_steps=self.next_steps,
            local_size=self.local_size,
        )
        return inputs, targets

    def get_house(self, index: int) -> torch.Tensor:
        """ Get the annotation for the index-th house. Use for thorough evaluation """
        return self._all_houses[index]

    def get_num_houses(self) -> int:
        """ Get the total number of houses. Use for thorough evaluation """
        return len(self._all_houses)

    @staticmethod
    @torch.no_grad()
    def prepare_inputs(
        annotation: torch.Tensor,
        local_size: int = 7,
        global_size: int = 21,
        history: int = 3,
    ) -> Dict[str, torch.Tensor]:
        """ Convert annotation to input tensors

        Args:
            annotation (torch.Tensor): M x 4 int tensor, where M is the number of
                prebuilt blocks. The first column is the block type, followed by the
                block coordinates.

        Returns:
            {
                "local": float tensor of shape (C * H, D, D, D),
                "global": float tensor of shape (1, G, G, G),
                "center": int tensor of shape (3,), the coordinate of the last block
            }
            where C is the number of block types, H is the history length, D is the
            local size, and G is the global size.
        """
        global_inputs = Craft3DDataset._convert_to_voxels(
            annotation, size=global_size, occupancy_only=True
        )
        local_inputs = Craft3DDataset._convert_to_voxels(
            annotation, size=local_size, occupancy_only=False
        )
        if len(annotation) == 0:
            return {
                "local": local_inputs.repeat(history, 1, 1, 1),
                "global": global_inputs,
                "center": torch.zeros((3,), dtype=torch.int64),
            }

        last_coord = annotation[-1, 1:]
        center_coord = last_coord.new_full((3,), local_size // 2)
        local_history = [local_inputs]
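        # Each history frame clones the previous one and erases the block placed
        # at that step, approximating the house as it looked i steps earlier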
        for i in range(len(annotation) - 1, len(annotation) - history, -1):
            if i < 0:
                local_history.append(torch.zeros_like(local_inputs))
            else:
                prev_inputs = local_history[-1].clone()
                prev_coord = annotation[i, 1:] - last_coord + center_coord
                if all((prev_coord >= 0) & (prev_coord < local_size)):
                    x, y, z = prev_coord
                    prev_inputs[:, x, y, z] = 0
                local_history.append(prev_inputs)
        local_inputs = torch.cat(local_history, dim=0)
        return {"local": local_inputs, "global": global_inputs, "center": last_coord}

    @staticmethod
    @torch.no_grad()
    def prepare_targets(
        annotation: torch.Tensor, next_steps: int = 1, local_size: int = 7
    ) -> Dict[str, torch.Tensor]:
        """ Convert annotation to target tensors

        Args:
            annotation (torch.Tensor): (M + 1) x 4 int tensor, where M is the number of
                blocks to build, plus one for the last built block. The first column
                is the block type, followed by the block coordinates.

        Returns:
            {
                "coords": int tensor of shape (A,)
                "types": int tensor of shape (A,)
            }
            where A is the number of next steps to be considered as targets
        """
        coords_targets = torch.full((next_steps,), -100, dtype=torch.int64)
        types_targets = coords_targets.clone()

        if len(annotation) <= 1:
            return {"coords": coords_targets, "types": types_targets}

        offsets = torch.tensor([local_size * local_size, local_size, 1])
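        # offsets flatten a relative (x, y, z) row-major: index = x*D*D + y*D + z,
        # so with local_size D = 3 the coordinate (2, 2, 2) encodes to 2*9 + 2*3 + 2 = 26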
        last_coord = annotation[0, 1:]
        center_coord = last_coord.new_full((3,), local_size // 2)

        N = min(1 + next_steps, len(annotation))
        next_types = annotation[1:N, 0].clone()
        next_coords = annotation[1:N, 1:] - last_coord + center_coord
        mask = (next_coords < 0) | (next_coords >= local_size)
        mask = mask.any(dim=1)
        next_coords = (next_coords * offsets).sum(dim=1)
        next_coords[mask] = -100
        next_types[mask] = -100

        coords_targets[: len(next_coords)] = next_coords
        types_targets[: len(next_types)] = next_types

        return {"coords": coords_targets, "types": types_targets}

    @staticmethod
    def _convert_to_voxels(
        annotation: torch.Tensor, size: int, occupancy_only: bool = False
    ) -> torch.Tensor:
        voxels_shape = (
            (1, size, size, size)
            if occupancy_only
            else (Craft3DDataset.NUM_BLOCK_TYPES, size, size, size)
        )
        if len(annotation) == 0:
            return torch.zeros(voxels_shape, dtype=torch.float32)

        annotation = annotation.clone()
        if occupancy_only:
            # No block types. Just coordinate occupancy
            annotation[:, 0] = 0
        # Shift the coordinates to make the last block centered
        last_coord = annotation[-1, 1:]
        center_coord = last_coord.new_tensor([size // 2, size // 2, size // 2])
        annotation[:, 1:] += center_coord - last_coord
        # Find valid annotation that inside the cube
        valid_mask = (annotation[:, 1:] >= 0) & (annotation[:, 1:] < size)
        valid_mask = valid_mask.all(dim=1)
        annotation = annotation[valid_mask]
        # Use sparse tensor to construct the voxels cube
        return torch.sparse.FloatTensor(
            annotation.t(), torch.ones(len(annotation)), voxels_shape
        ).to_dense()

    def _log(self, msg: str):
        if self.logger is None:
            print(msg)
        else:
            self.logger.info(msg)

    def _has_raw_data(self) -> bool:
        return osp.isdir(osp.join(self.data_dir, "houses"))

    def _download(self):
        os.makedirs(self.data_dir, exist_ok=True)

        tar_path = osp.join(self.data_dir, "houses.tar.gz")
        if not osp.isfile(tar_path):
            self._log(f"Downloading dataset from {Craft3DDataset.URL}")
            response = requests.get(Craft3DDataset.URL, allow_redirects=True)
            if response.status_code != 200:
                raise RuntimeError(
                    f"Failed to retrieve image from url: {Craft3DDataset.URL}. "
                    f"Status: {response.status_code}"
                )
            with open(tar_path, "wb") as f:
                f.write(response.content)

        extracted_dir = osp.join(self.data_dir, "houses")
        if not osp.isdir(extracted_dir):
            self._log(f"Extracting dataset to {extracted_dir}")
            tar = tarfile.open(tar_path, "r")
            tar.extractall(self.data_dir)

    def _load_dataset(self):
        splits_path = osp.join(self.data_dir, "splits.json")
        if not osp.isfile(splits_path):
            raise RuntimeError(f"Split file not found at: {splits_path}")

        with open(splits_path, "r") as f:
            splits = json.load(f)

        self._all_houses = []
        max_len = 0
        for filename in splits[self.subset]:
            annotation = osp.join(self.data_dir, "houses", filename, "placed.json")
            if not osp.isfile(annotation):
                warnings.warn(f"No annotation file for: {annotation}")
                continue
            annotation = self._load_annotation(annotation)
            if len(annotation) >= 100:
                self._all_houses.append(annotation)
                max_len = max(max_len, len(annotation))

        if self.next_steps <= 0:
            self.next_steps = max_len

    def _load_annotation(self, annotation_path: str) -> torch.Tensor:
        with open(annotation_path, "r") as f:
            annotation = json.load(f)
        final_house = {}
        types_and_coords = []
        last_timestamp = -1
        for i, item in enumerate(annotation):
            timestamp, annotator_id, coordinate, block_info, action = item
            assert timestamp >= last_timestamp
            last_timestamp = timestamp
            coordinate = tuple(np.asarray(coordinate).astype(np.int64).tolist())
            block_type = np.asarray(block_info, dtype=np.uint8).astype(np.int64)[0]
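            # a "B" action removes a previously placed block from the final house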
            if action == "B":
                final_house.pop(coordinate, None)
            else:
                final_house[coordinate] = i
            types_and_coords.append((block_type,) + coordinate)
        indices = sorted(final_house.values())
        types_and_coords = [types_and_coords[i] for i in indices]
        return torch.tensor(types_and_coords, dtype=torch.int64)

    def _find_valid_items(self):
        self._valid_indices = {}
        for i, annotation in enumerate(self._all_houses):
            diff_coord = annotation[:-1, 1:] - annotation[1:, 1:]
            valids = abs(diff_coord) <= self.max_local_distance
            valids = valids.all(dim=1).nonzero(as_tuple=True)[0]
            self._valid_indices[i] = valids.tolist()

        self._flattened_valid_indices = []
        for i, indices in self._valid_indices.items():
            for j in indices:
                self._flattened_valid_indices.append((i, j))

if __name__ == "__main__":
    work_dir = osp.join(osp.dirname(osp.abspath(__file__)), "..")
    dataset = Craft3DDataset(osp.join(work_dir, "data"), "val")
    for i in range(5):
        inputs, targets = dataset[i]
        print(targets)

================================================================================

File path: .\voxelcnn\evaluators.py

#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from typing import Dict, Tuple

import torch
from torch import nn
from tqdm import tqdm

from .datasets import Craft3DDataset
from .predictor import Predictor

class Accuracy(nn.Module):
    def __init__(self, next_steps: int = 1):
        """ Compute the accuracy of coordinates and types predictions

        Args:
            next_steps (int): The number of future ground truth steps to be considered.
                Default: 1
        """
        super().__init__()
        self.next_steps = next_steps

    def step(
        self, outputs: Dict[str, torch.Tensor], targets: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """ Compute sample-wise accuracy in a minibatch

        Args:
            outputs (dict): A dict of
                {
                    "coords": float tensor of shape (N, 1, D, D, D),
                    "types": float tensor of shape (N, C, D, D, D),
                }
                where N is the batch size, C is the number of block types, and D is the
                local size.

            targets (dict): A dict of
                {
                    "coords": int tensor of shape (N, A), the encoded target coordinates
                    "types": int tensor of shape (N, A), the target block types
                }
                where N is the batch size and A is the number of next groundtruth
                actions.
        """
        N, C, D, D, D = outputs["types"].shape
        assert outputs["coords"].shape == (N, 1, D, D, D)
        assert targets["coords"].shape == targets["types"].shape

        K = self.next_steps if self.next_steps > 0 else targets["coords"].shape[1]
        if targets["coords"].shape[1] < K:
            raise RuntimeError(f"Targets do not contain next {K} steps")

        coords_targets = targets["coords"][:, :K].view(N, -1)
        types_targets = targets["types"][:, :K].view(N, -1)

        coords_predictions = outputs["coords"].view(N, -1).argmax(dim=1, keepdim=True)
        coords_correct = (coords_predictions == coords_targets).any(dim=1)

        types_predictions = (
            outputs["types"]
            .view(N, C, D * D * D)
            .gather(dim=2, index=coords_predictions.view(N, 1, 1).expand(N, C, 1))
            .argmax(dim=1)
            .view(N, -1)
        )
        types_correct = (types_predictions == types_targets).any(dim=1)

        both_correct = coords_correct & types_correct
        return both_correct

    def stop(self, correct: torch.Tensor) -> torch.Tensor:
        """ Average over batched results

        Args:
            correct (torch.Tensor): (N,) bool vector, whether each sample is correct

        Returns:
            A float scalar tensor of averaged accuracy
        """
        return correct.float().mean()

    def forward(
        self, outputs: Dict[str, torch.Tensor], targets: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        return self.stop(self.step(outputs, targets))

class CCA(object):
    def __init__(
        self,
        percentages: Tuple[float] = (0.1, 0.25, 0.5, 0.75, 0.9),
        local_size: int = 7,
        global_size: int = 21,
        history: int = 3,
    ):
        """ Consecutive Correct Actions

        Args:
            percentages (tuple): Evaluate based on these percentages of blocks
                prebuilt for each house. Average the CCA over these results
            local_size (int): Local context size. Default: 7
            global_size (int): Global context size. Default: 21
            history (int): Number of previous steps considered as inputs. Default: 3
        """
        super().__init__()
        self.percentages = percentages
        self.local_size = local_size
        self.global_size = global_size
        self.history = history

    @torch.no_grad()
    def evaluate(self, dataset: Craft3DDataset, model: nn.Module) -> Dict[str, float]:
        """ Evaluate a model by CCA over a given dataset

        Args:
            dataset (Craft3DDataset): A dataset
            model (nn.Module): A VoxelCNN model

        Returns:
            A dict of string to float
                {
                    "cca_x%": where x is the percentage of prebuilt house,
                    "cca_avg": averaged CCA across all the percentages,
                }
        """
        predictor = Predictor(
            model.eval(),
            local_size=self.local_size,
            global_size=self.global_size,
            history=self.history,
        )
        all_results = defaultdict(list)
        for i in tqdm(range(dataset.get_num_houses())):
            house = dataset.get_house(i)
            for p in self.percentages:
                start = int(len(house) * p)
                predictions = predictor.predict_until_wrong(house, start=start)
                all_results[p].append(len(predictions))

        results = {
            f"cca_{k:.0%}": float(torch.tensor(v).float().mean())
            for k, v in all_results.items()
        }
        results["cca_avg"] = float(torch.tensor(list(results.values())).float().mean())
        return results

class MTC(object):
    def __init__(
        self,
        percentage: float = 0.1,
        local_size: int = 7,
        global_size: int = 21,
        history: int = 3,
    ):
        """ Mistakes to Complete

        Args:
            percentage (float): Evaluate based on this percentage of blocks
                prebuilt for each house
            local_size (int): Local context size. Default: 7
            global_size (int): Global context size. Default: 21
            history (int): Number of previous steps considered as inputs. Default: 3
        """
        super().__init__()
        self.percentage = percentage
        self.local_size = local_size
        self.global_size = global_size
        self.history = history

    @torch.no_grad()
    def evaluate(self, dataset: Craft3DDataset, model: nn.Module) -> Dict[str, float]:
        """ Evaluate a model by MTC over a given dataset

        Args:
            dataset (Craft3DDataset): A dataset
            model (nn.Module): A VoxelCNN model

        Returns:
            A dict of string to float
                {
                    "mtc": Unnormalized MTC,
                    "mtc_normed": Normalized MTC by the total blocks of each house,
                }
        """
        predictor = Predictor(
            model.eval(),
            local_size=self.local_size,
            global_size=self.global_size,
            history=self.history,
        )
        unnormed = []
        normed = []
        for i in tqdm(range(dataset.get_num_houses())):
            house = dataset.get_house(i)
            start = int(len(house) * self.percentage)
            is_correct = predictor.predict_until_end(house, start=start)["is_correct"]
            num_mistakes = len(is_correct) - int(is_correct.sum())
            total = len(is_correct)
            unnormed.append(num_mistakes)
            normed.append(num_mistakes / total)
        return {
            "mtc": float(torch.tensor(unnormed).float().mean()),
            "mtc_normed": float(torch.tensor(normed).float().mean()),
        }

    def _compute(self, predictor: Predictor, annotation: torch.Tensor) -> int:
        start = int(len(annotation) * self.percentage)
        is_correct = predictor.predict_until_end(annotation, start=start)["is_correct"]
        num_mistakes = len(is_correct) - is_correct.sum()
        return num_mistakes

================================================================================

File path: .\voxelcnn\models.py

#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, List

import torch
from torch import nn

def conv3d(
    in_channels: int,
    out_channels: int,
    kernel_size: int = 3,
    stride: int = 1,
    padding: int = 1,
) -> List[nn.Module]:
    conv = nn.Conv3d(
        in_channels,
        out_channels,
        kernel_size,
        padding=padding,
        stride=stride,
        bias=False,
    )
    bn = nn.BatchNorm3d(out_channels)
    relu = nn.ReLU()
    return [conv, bn, relu]

class VoxelCNN(nn.Module):
    def __init__(
        self,
        local_size: int = 7,
        global_size: int = 21,
        history: int = 3,
        num_block_types: int = 256,
        num_features: int = 16,
    ):
        """ VoxelCNN model

        Args:
            local_size (int): Local context size. Default: 7
            global_size (int): Global context size. Default: 21
            history (int): Number of previous steps considered as inputs. Default: 3
            num_block_types (int): Total number of different block types. Default: 256
            num_features (int): Number of channels output by the encoders. Default: 16
        """
        super().__init__()
        self.local_size = local_size
        self.global_size = global_size
        self.history = history
        self.num_block_types = num_block_types
        self.num_features = num_features

        self.local_encoder = self._build_local_encoder()
        self.global_encoder = self._build_global_encoder()

        self.feature_extractor = nn.Sequential(
            *conv3d(self.num_features * 2, self.num_features, kernel_size=1, padding=0)
        )
        self.coords_predictor = nn.Conv3d(
            self.num_features, 1, kernel_size=1, padding=0
        )
        self.types_predictor = nn.Conv3d(
            self.num_features, self.num_block_types, kernel_size=1, padding=0
        )

        self._init_params()

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Args:
            inputs (dict): A dict of inputs
                {
                    "local": float tensor of shape (N, C * H, D, D, D),
                    "global": float tensor of shape (N, 1, G, G, G),
                    "center": int tensor of shape (N, 3), the coordinate of the last
                        blocks, optional
                }
                where N is the batch size, C is the number of block types, H is the
                history length, D is the local size, and G is the global size.

        Returns:
            A dict of coordinates and types scores
            {
                "coords": float tensor of shape (N, 1, D, D, D),
                "types": float tensor of shape (N, C, D, D, D),
                "center": int tensor of shape (N, 3), the coordinate of the last blocks.
                    Output only when inputs have "center"
            }
        """
        outputs = torch.cat(
            [
                self.local_encoder(inputs["local"]),
                self.global_encoder(inputs["global"]),
            ],
            dim=1,
        )
        outputs = self.feature_extractor(outputs)
        ret = {
            "coords": self.coords_predictor(outputs),
            "types": self.types_predictor(outputs),
        }
        if "center" in inputs:
            ret["center"] = inputs["center"]
        return ret

    def _build_local_encoder(self) -> nn.Module:
        layers = conv3d(self.num_block_types * self.history, self.num_features)
        for _ in range(3):
            layers.extend(conv3d(self.num_features, self.num_features))
        return nn.Sequential(*layers)

    def _build_global_encoder(self) -> nn.Module:
        layers = conv3d(1, self.num_features)
        layers.extend(conv3d(self.num_features, self.num_features))
        layers.append(
            nn.AdaptiveMaxPool3d((self.local_size, self.local_size, self.local_size))
        )
        layers.extend(conv3d(self.num_features, self.num_features))
        return nn.Sequential(*layers)

    def _init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                if m.bias is None:
                    # Normal Conv3d layers
                    nn.init.kaiming_normal_(m.weight, mode="fan_out")
                else:
                    # Last layers of coords and types predictions
                    nn.init.normal_(m.weight, mean=0, std=0.001)
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

================================================================================

File path: .\voxelcnn\predictor.py

#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import torch
from torch import nn

from .datasets import Craft3DDataset
from .utils import OrderedSet

class Predictor(object):
    def __init__(
        self,
        model: nn.Module,
        local_size: int = 7,
        global_size: int = 21,
        history: int = 3,
    ):
        """ Predictor for inference and evaluation

        Args:
            model (nn.Module): VoxelCNN model
            local_size (int): Local context size. Default: 7
            global_size (int): Global context size. Default: 21
            history (int): Number of previous steps considered as inputs. Default: 3
        """
        super().__init__()
        self.model = model
        self.local_size = local_size
        self.global_size = global_size
        self.history = history

    @torch.no_grad()
    def predict(self, annotation: torch.Tensor, steps: int = 1) -> torch.Tensor:
        """ Continuous prediction for given steps starting from a prebuilt house

        Args:
            annotation (torch.Tensor): M x 4 int tensor, where M is the number of
                prebuilt blocks. The first column is the block type, followed by the
                absolute block coordinates.
            steps (int): How many steps to predict. Default: 1

        Returns:
            An int tensor of (steps, 4) if steps > 1, otherwise (4,). Denoting the
            predicted blocks. The first column is the block type, followed by the
            absolute block coordinates.
        """
        predictions = []
        for _ in range(steps):
            inputs = Craft3DDataset.prepare_inputs(
                annotation,
                local_size=self.local_size,
                global_size=self.global_size,
                history=self.history,
            )
            inputs = {k: v.unsqueeze(0) for k, v in inputs.items()}
            if next(self.model.parameters()).is_cuda:
                inputs = {k: v.cuda() for k, v in inputs.items()}
            outputs = self.model(inputs)
            prediction = self.decode(outputs).cpu()
            predictions.append(prediction)
            annotation = torch.cat([annotation, prediction], dim=0)
        predictions = torch.cat(predictions, dim=0)
        return predictions.squeeze()

    @torch.no_grad()
    def predict_until_wrong(
        self, annotation: torch.Tensor, start: int = 0
    ) -> torch.Tensor:
        """ Starting from a house, predict until a wrong prediction occurs

        Args:
            annotation (torch.Tensor): M x 4 int tensor, where M is the number of
                prebuilt blocks. The first column is the block type, followed by the
                absolute block coordinates.
            start (int): Starting from this number of blocks prebuilt

        Returns:
            An int tensor of (steps, 4). Denoting the correctly predicted blocks. The
                first column is the block type, followed by the absolute block
                coordinates.
        """
        built = annotation[:start].tolist()
        to_build = {tuple(x) for x in annotation[start:].tolist()}
        predictions = []
        while len(to_build) > 0:
            block = self.predict(torch.tensor(built, dtype=torch.int64)).tolist()
            if tuple(block) not in to_build:
                break
            predictions.append(block)
            built.append(block)
            to_build.remove(tuple(block))
        return torch.tensor(predictions)

    @torch.no_grad()
    def predict_until_end(
        self, annotation: torch.Tensor, start: int = 0
    ) -> Dict[str, torch.Tensor]:
        """ Starting from a house, predict until a the house is completed

        Args:
            annotation (torch.Tensor): M x 4 int tensor, where M is the number of
                prebuilt blocks. The first column is the block type, followed by the
                absolute block coordinates.
            start (int): Starting from this number of blocks prebuilt

        Returns:
            A dict of
                {
                    "predictions": int tensor of shape (steps, 4),
                    "targets": int tensor of shape (steps, 4),
                    "is_correct": bool tensor of shape (steps,)
                }
            where steps is the number of blocks predicted, in order to complete the
            house. predictions is the model predictions, which could be wrong at
            some steps. targets contains the corrected blocks. is_correct
            denotes whether the prediction is correct at certain step.
        """
        built = annotation[:start].tolist()
        to_build = OrderedSet(tuple(x) for x in annotation[start:].tolist())

        predictions = []
        targets = []
        is_correct = []
        while len(to_build) > 0:
            block = self.predict(torch.tensor(built, dtype=torch.int64)).tolist()
            predictions.append(block)
            if tuple(block) in to_build:
                # Prediction is correct. Add it to the house
                is_correct.append(True)
                targets.append(block)
                built.append(block)
                to_build.remove(tuple(block))
            else:
                # Prediction is wrong. Correct it by the ground truth block
                is_correct.append(False)
                correction = to_build.pop(last=False)
                targets.append(correction)
                built.append(correction)

        return {
            "predictions": torch.tensor(predictions),
            "targets": torch.tensor(targets),
            "is_correct": torch.tensor(is_correct),
        }

    @staticmethod
    def decode(outputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """ Convert model output scores to absolute block coordinates and types

        Args:
            outputs (dict): A dict of coordinates and types scores
                {
                    "coords": float tensor of shape (N, 1, D, D, D),
                    "types": float tensor of shape (N, C, D, D, D),
                    "center": int tensor of shape (N, 3), the coordinate of the
                        last blocks
                }
                where N is the batch size, C is the number of block types, D is the
                local context size
        Returns:
            An int tensor of shape (N, 4), where the first column is the block type,
            followed by absolute block coordinates
        """
        N, C, D, D, D = outputs["types"].shape
        assert outputs["coords"].shape == (N, 1, D, D, D)
        assert outputs["center"].shape == (N, 3)

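        # Pick the voxel with the highest placement score over the flattened
        # D^3 local volume, then read the type scores at that same voxel.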
        coords_predictions = outputs["coords"].view(N, -1).argmax(dim=1)

        types_predictions = (
            outputs["types"]
            .view(N, C, D * D * D)
            .gather(dim=2, index=coords_predictions.view(N, 1, 1).expand(N, C, 1))
            .argmax(dim=1)
            .view(-1)
        )

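        # Unravel the flat index into (x, y, z): the layout is x-major with z
        # varying fastest, i.e. index = x * D * D + y * D + z.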
        z = coords_predictions % D
        x = coords_predictions // (D * D)
        y = (coords_predictions - z - x * (D * D)) // D
        ret = torch.stack([types_predictions, x, y, z], dim=1)
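        # Shift the local-volume coordinates back to absolute world
        # coordinates centered on the last placed block.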
        ret[:, 1:] += outputs["center"] - ret.new_tensor([D, D, D]) // 2

        return ret

if __name__ == "__main__":
    from .models import VoxelCNN

    model = VoxelCNN()
    predictor = Predictor(model.eval())
    annotation = torch.tensor([[0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]])
    results = predictor.predict_until_end(annotation)
    print(results)

================================================================================

文件路径: .\voxelcnn\summary.py

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
from collections import defaultdict
from typing import Dict, List, Optional, Union

import numpy as np
import torch

class Summary(object):
    Datum = Dict[str, Union[float, torch.Tensor]]

    def __init__(self, window_size: int = 20, logger: Optional[logging.Logger] = None):
        """ Training summary helper

        Args:
            window_size (int): Compute the moving average of scalars within this number
                of history values
            logger (logging.Logger, optional): A logger. Default: None, meaning will
                print to stdout
        """
        super().__init__()
        self.window_size = window_size
        self.logger = logger
        self._times = defaultdict(list)
        self._lrs = defaultdict(list)
        self._losses = defaultdict(list)
        self._metrics = defaultdict(list)

    def add_times(self, times: Datum):
        for k, v in times.items():
            if isinstance(v, torch.Tensor):
                v = float(v)
            self._times[k].append(v)

    def add_lrs(self, lrs: Datum):
        for k, v in lrs.items():
            if isinstance(v, torch.Tensor):
                v = float(v)
            self._lrs[k].append(v)

    def add_losses(self, losses: Datum):
        for k, v in losses.items():
            if isinstance(v, torch.Tensor):
                v = float(v)
            self._losses[k].append(v)

    def add_metrics(self, metrics: Datum):
        for k, v in metrics.items():
            if isinstance(v, torch.Tensor):
                v = float(v)
            self._metrics[k].append(v)

    def add(
        self,
        times: Optional[Datum] = None,
        lrs: Optional[Datum] = None,
        losses: Optional[Datum] = None,
        metrics: Optional[Datum] = None,
    ):
        if times is not None:
            self.add_times(times)
        if lrs is not None:
            self.add_lrs(lrs)
        if losses is not None:
            self.add_losses(losses)
        if metrics is not None:
            self.add_metrics(metrics)

    def print_current(self, prefix: Optional[str] = None):
        items = [] if prefix is None else [prefix]
        items += [f"{k}: {v[-1]:.6f}" for k, v in list(self._lrs.items())]
        items += [
            f"{k}: {v[-1]:.3f} ({self._moving_average(v):.3f})"
            for k, v in list(self._times.items())
            + list(self._losses.items())
            + list(self._metrics.items())
        ]
        self._log("  ".join(items))

    def _moving_average(self, values: List[float]) -> float:
        return np.mean(values[max(0, len(values) - self.window_size) :])

    def _log(self, msg: str):
        if self.logger is None:
            print(msg)
        else:
            self.logger.info(msg)

================================================================================

文件路径: .\voxelcnn\utils.py

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import itertools
import logging
import operator
import os
import sys
from collections import OrderedDict
from os import path as osp
from time import time as tic
from typing import Dict, List, Tuple, Union

import torch

StructuredData = Union[
    Dict[str, "StructuredData"],
    List["StructuredData"],
    Tuple["StructuredData"],
    torch.Tensor,
]

def to_cuda(data: StructuredData) -> StructuredData:
    if isinstance(data, torch.Tensor):
        return data.cuda(non_blocking=True)
    if isinstance(data, dict):
        return {k: to_cuda(v) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(to_cuda(x) for x in data)
    raise ValueError(f"Unknown data type: {type(data)}")

def collate_batches(batches):
    """ Collate a list of batches into a batch

    Args:
        batches (list): a list of batches. Each batch could be a tensor, dict, tuple,
            list, string, number, namedtuple

    Returns:
        batch: collated batches where tensors are concatenated along the outer dim.
            For example, given samples [torch.empty(3, 5), torch.empty(3, 5)], the
            result will be a tensor of shape (6, 5).
    """
    batch = batches[0]
    if isinstance(batch, torch.Tensor):
        return batch if len(batches) == 1 else torch.cat(batches, 0)
    if isinstance(batch, (list, tuple)):
        transposed = zip(*batches)
        return type(batch)([collate_batches(b) for b in transposed])
    if isinstance(batch, dict):
        return {k: collate_batches([d[k] for d in batches]) for k in batch}
    # Otherwise, just return the input as it is
    return batches

def setup_logger(name=None, save_file=None, rank=0, level=logging.DEBUG):
    """ Setup logger once for each process

    Logging messages will be printed to stdout, with DEBUG level. If save_file is set,
    messages will be logged to disk files as well.

    When multiprocessing, this function must be called inside each process, but only
    the main process (rank == 0) will log to stdout and files.

    Args:
        name (str, optional): Name of the logger. Default: None, will use the root
            logger
        save_file (str, optional): Path to a file where log messages are saved into.
            Default: None, do not log to file
        rank (int): Rank of the process. Default: 0, the main process
        level (int): An integer of logging level. Recommended to be one of
            logging.DEBUG / INFO / WARNING / ERROR / CRITICAL. Default: logging.DEBUG
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.propagate = False

    # Don't log results for the non-main process
    if rank > 0:
        logging.disable(logging.CRITICAL)
        return logger

    sh = logging.StreamHandler(stream=sys.stdout)
    sh.setLevel(level)
    formatter = logging.Formatter(
        fmt="%(asctime)s %(levelname)s %(filename)s:%(lineno)4d: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    sh.setFormatter(formatter)
    logger.addHandler(sh)
    logger.propagate = False  # prevent double logging

    if save_file is not None:
        os.makedirs(osp.dirname(osp.abspath(save_file)), exist_ok=True)
        fh = logging.FileHandler(save_file)
        fh.setLevel(level)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    return logger

class Section(object):
    """
    Examples
    --------
    >>> with Section('Loading Data'):
    >>>     num_samples = load_data()
    >>>     print(f'=> {num_samples} samples loaded')

    will print out something like
    => Loading data ...
    => 42 samples loaded
    => Done!
    """

    def __init__(self, description, newline=True, logger=None, timing="auto"):
        super(Section, self).__init__()
        self.description = description
        self.newline = newline
        self.logger = logger
        self.timing = timing
        self.t0 = None
        self.t1 = None

    def __enter__(self):
        self.t0 = tic()
        self.print_message("=> " + str(self.description) + " ...")

    def __exit__(self, type, value, traceback):
        self.t1 = tic()
        msg = "=> Done"
        if self.timing != "none":
            t, unit = self._get_time_and_unit(self.t1 - self.t0)
            msg += f" in {t:.3f} {unit}"
        self.print_message(msg)
        if self.newline:
            self.print_message()

    def print_message(self, message=""):
        if self.logger is None:
            try:
                print(message, flush=True)
            except TypeError:
                print(message)
                sys.stdout.flush()
        else:
            self.logger.info(message)

    def _get_time_and_unit(self, time):
        if self.timing == "auto":
            return self._auto_determine_time_and_unit(time)
        elif self.timing == "us":
            return time * 1000000, "us"
        elif self.timing == "ms":
            return time * 1000, "ms"
        elif self.timing == "s":
            return time, "s"
        elif self.timing == "m":
            return time / 60, "m"
        elif self.timing == "h":
            return time / 3600, "h"
        else:
            raise ValueError(f"Unknown timing mode: {self.timing}")

    def _auto_determine_time_and_unit(self, time):
        if time < 0.001:
            return time * 1000000, "us"
        elif time < 1.0:
            return time * 1000, "ms"
        elif time < 60:
            return time, "s"
        elif time < 3600:
            return time / 60, "m"
        else:
            return time / 3600, "h"

class OrderedSet(object):
    """ A set that remembers insertion order """

    def __init__(self, iterable=None):
        self.__dict = OrderedDict()
        for value in iterable or []:
            self.add(value)

    def add(self, value):
        if value not in self.__dict:
            self.__dict[value] = None

    def remove(self, value):
        if value in self.__dict:
            del self.__dict[value]
        else:
            raise KeyError(value)

    def pop(self, last=True):
        """ Pop the last or first element

        Args:
            last (bool): If True, pop the last element (most recently inserted).
                Otherwise pop the first element (oldest).
        """
        return self.__dict.popitem(last=last)[0]

    def __iter__(self):
        return iter(self.__dict)

    def __len__(self):
        return len(self.__dict)

    def __contains__(self, value):
        return value in self.__dict

    def intersection(self, other):
        new = OrderedSet()
        for value in self:
            if value in other:
                new.add(value)
        return new

    def __and__(self, other):
        return self.intersection(other)

    def union(self, other):
        new = OrderedSet()
        for value in itertools.chain(self, other):
            new.add(value)
        return new

    def __or__(self, other):
        return self.union(other)

    def __ge__(self, other):
        return set(self.__dict) >= set(other.__dict)

    def __eq__(self, other):
        if isinstance(other, OrderedSet):
            return len(self) == len(other) and all(map(operator.eq, self, other))
        return set(self) == other

    def __ne__(self, other):
        return not self == other

    def __repr__(self):
        return f"{self.__class__.__name__}([{', '.join(repr(x) for x in self)}])"

    def __str__(self):
        return repr(self)
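    # Example:
    #   s = OrderedSet([3, 1, 2])
    #   s.add(1)            # no-op: 1 is already present
    #   list(s)             # -> [3, 1, 2], insertion order preserved
    #   s.pop(last=False)   # -> 3, the oldest element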

================================================================================

文件路径: .\voxelcnn\__init__.py

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

================================================================================

Thought

Designing a model for block-ID assignment

This task involves creating a model that takes block coordinates as input and returns predicted block IDs. I'll need to design an appropriate dataset and display the loss (plus a running average) after every training step. Each epoch must run inference once. The script should also generate a "matchbox" house with unknown block IDs and infer the ID of every block, using any pre-labeled blocks as context. Batch size should be adjustable via the command line, and the model must save to the logs directory. Finally, I'll provide usage instructions along with the .py script.

Designing model architecture

It seems the user wants block-ID prediction with a transformer-style model, specifically the requested DiT (Diffusion Transformer) architecture. I'll start thinking about an architecture whose input is the 3D coordinates (x, y, z), possibly with some global context features. For simplicity, I might embed the coordinates with a positional encoding and feed them into a small transformer. PyTorch Lightning can handle the training loop. I might also reuse parts of the existing code if time is tight.

Designing block ID prediction model

I'll design a dataset class for block ID classification, adapting the Craft3DDataset to include occupancy data as input, and block type as output. Since the user wants input of coordinates without block IDs, the dataset will include occupancy grids and masks for unknown types, with the model inferring IDs. For efficiency, I'll create a simple CNN or transformer-based model. I'll also include the ability to set batch sizes and epochs via command-line arguments, and implement PyTorch Lightning for training. After implementing the code, I'll save it as a script for the user to download.

Thought for 51 seconds

Answer

I've saved the complete, revised script as dit_blocktypes.py. Download it and it's ready to run:

Download dit_blocktypes.py


Quick usage guide

1. Install the dependencies

   ```bash
   pip install torch torchvision torchaudio pytorch-lightning numpy
   ```

2. Prepare the data

   Extract the official 3D-Craft dataset into any directory, e.g. data/ (it should contain houses/splits.json).

3. Train

   ```bash
   python dit_blocktypes.py --data_dir data --batch_size 128 --epochs 12
   ```

   • The model and logs are saved under the logs/ directory.
   • Every training step, the progress bar shows the instantaneous loss and the moving-average loss.
   • After each epoch, the script runs inference on a small "matchbox" house and prints a 6 × 6 cross-section, so you can check by eye whether the model fills in the missing block IDs from context (a sketch of such a matchbox builder follows below).
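Below is a minimal sketch of how such a test building can be generated; make_matchbox, its seed_ids argument, and the -1 "unknown" marker are illustrative choices, not necessarily what dit_blocktypes.py itself uses. It collects the shell coordinates of a hollow box, marks all block IDs as unknown, and seeds a few known IDs so the model has context to condition on:

```python
import torch

def make_matchbox(w=6, h=4, d=6, seed_ids=None):
    """Shell coordinates of a hollow w x h x d box, with mostly unknown IDs."""
    # Keep only the boundary voxels, i.e. the walls, floor, and roof.
    coords = [
        (x, y, z)
        for x in range(w)
        for y in range(h)
        for z in range(d)
        if x in (0, w - 1) or y in (0, h - 1) or z in (0, d - 1)
    ]
    coords = torch.tensor(coords, dtype=torch.float32)
    ids = torch.full((len(coords),), -1, dtype=torch.long)  # -1 marks "unknown"
    # Seed a handful of known block IDs to test in-context prediction.
    for idx, block_id in (seed_ids or {}).items():
        ids[idx] = block_id
    return coords, ids
```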
4. Run inference on your own building

   After training, feed your coordinate tensor to DiTMini:

   ```python
   import torch
   from dit_blocktypes import DiTMini

   model = DiTMini()
   model.load_state_dict(torch.load('logs/lightning_logs/version_0/checkpoints/…')['state_dict'])
   model.eval()
   coords = torch.tensor([[10., 12., 5.], [11., 12., 5.]], dtype=torch.float32)
   with torch.no_grad():
       preds = model(coords).argmax(dim=1)
   print(preds)  # the predicted block ID for each coordinate
   ```

With this script you can set the batch size from the terminal via --batch_size, and the whole pipeline follows PyTorch Lightning conventions, making it easy to extend. Happy training!
