假设我有一个火炬 dataloader = DataLoader(...)
对象。每当我在函数中调用 for data, label in dataloader:
时,我都不想遍历整个数据集,所以目前我使用:
dataloader = DataLoader(...)
iter_dataloader = iter(dataloader)
batch = iter_dataloader.next() # Set the first batch
def train_batch():
data, label = batch
prediction = model(data)
# Do fancy things here
try:
batch = iter_dataloader.next() # Load the next batch
except:
iter_dataloader = iter(dataloader) # if the iterator object reaches the end, reset the dataloader
batch = iter_dataloader.next()
for _ in range(N):
train_batch() # This function is called multiple times
对于 train_batch()
的每次调用,我从数据集中获取一批,训练模型,然后加载下一批。如果没有剩余批次,我会重置 DataLoader 对象。
现在我的问题:
iter
和 next
方法。每次我调用它时,它都会自动从中采样一批,并在到达末尾时自动重置。我听说过 Sampler
,但我没有使用过。K
批或 1/K
来代替批处理吗?