我有一个自定义的整理功能,该功能可以更改批处理中的张量形状(填充可变长度等)。当我运行两次迭代(或多个时期)时,它将引发错误。经过调查,这是因为第二次迭代的输入是第一次迭代的输出!因此,它要进行两次整理功能。如何避免呢?
train_dataset = DataLoader(dataset=train_set, collate_fn=ER_Collate,
batch_size=BATCH_SIZE, shuffle=True)
for i, train_ in enumerate(train_dataset):
pass
for i, train_ in enumerate(train_dataset)://error thrown
pass