我一直在努力为3D张量管理和创建批处理。我之前曾使用它作为创建一维张量批处理的方法。但是,在我目前的研究中,我需要根据形状为(1024,1024,2)的张量创建批次。
我创建了自定义数据,用作pytorch中DataLoader方法的输入。我为一维数组创建了以下内容:
class CustomDataset(Dataset):
def __init__(self, x_tensor, y_tensor):
self.xdomain = x_tensor
self.ydomain = y_tensor
def __getitem__(self, index):
return (self.xdomain[index], self.ydomain[index])
def __len__(self):
return len(self.xdomain)
效果很好,但是,我意识到这不适用于形状分别为(1024,1024,2)和(1024,1024,1)的张量x_tensor和y_tensor。我知道我必须以某种方式更改__ getitem __和__ len __函数,以便可以将张量分成几批。
我尝试了很多事情,但是我知道可以起作用的是,我可以将这些张量展平为形状(1024 x1024,2)和(1024x1024,1)。但是,我不仅必须更改我的NN定义,还必须更改我的代码。
因此,我想保持原样,并尝试了解如何尽可能创建这些功能。我对这些功能的了解是:
__len__
,以便len(dataset)返回数据集的大小。
__getitem__
支持索引,以便数据集[i]可用于获取第i个样本。
基于此知识,我创建了此类,该类查找前两个维度的索引(以查找第ith个样本)。但是,这将创建NN的输入为(1024x1024,2)和输出(1024x1024,1)。我希望它是(1024,1024,2)和(1024,1024,1)。
如果对Data Loader和mini-batch有更好了解的人可以解释我所缺少的内容,那可能会很棒。首先有可能吗?
感谢您阅读本文,如果这个问题太基础了,请对不起。我希望是清楚的。