DataLoader中的自定义collat​​e_fn更改数据结构

时间:2019-06-04 14:11:06

标签: pytorch

我有一个自定义的整理功能,该功能可以更改批处理中的张量形状(填充可变长度等)。当我运行两次迭代(或多个时期)时,它将引发错误。经过调查,这是因为第二次迭代的输入是第一次迭代的输出!因此,它要进行两次整理功能。如何避免呢?

train_dataset = DataLoader(dataset=train_set, collate_fn=ER_Collate,
                               batch_size=BATCH_SIZE, shuffle=True)
for i, train_ in enumerate(train_dataset):
   pass
for i, train_ in enumerate(train_dataset)://error thrown
   pass

0 个答案:

没有答案