我正在学习在模型训练和推理过程中管理批次的最佳方法和其他最佳实践,我有以下问题:
batch, label = batch.to(device), label.to(device)
model(batch)
# ..Training pass...
batch, label = batch.cpu(), label.cpu()
如果我在 Dataset
类中缓存我的数据,我如何确保可以在 GPU 上重复使用相同的批次以避免在 CPU 之间多次传输?答案 0 :(得分:0)
您不应该将数据移回 CPU。 GPU 上的数据分配由 PyTorch 处理。您应该使用 torch.utils.data.DataLoader
来处理从数据集加载的数据。但是,您必须自己在 GPU 上发送输入:基本上,每次您需要推断一些输出时,您都会将批次和标签发送到 GPU 并计算结果(和损失),仅此而已。