使用服装尺寸标签创建dataLoader

时间:2020-06-26 16:12:26

标签: python pytorch dataloader

我有一个标签大小为64的数据集,输入是15*17通道中大小为3的图像。因此,特征(图像)和标签的大小分别为n*3*15*17n*64。我想创建一个pytorch dataLoader以获得数据的小批量。我可以通过将标签的尺寸扩展到n*3*15*64并将其与特征(图像)连接以获取批次,然后在将它们传递到模型之前进行分离来做到这一点。但是,我认为这不是一个好主意,因为当n很大时,这在内存使用方面可能会非常昂贵,并且还会涉及一些不必要的计算。

我喜欢任何建议吗?

这是一个示例:

import torch import numpy as np 

n = 10 
a = np.ones((n,3,15,17))
b = np.ones((n,64))

data = torch.utils.data.DataLoader(train_data, 5, shuffle = True)

我不确定如何从train_dataa中构造b

0 个答案:

没有答案