我使用mx.io.DataIter
类编写了一个自定义数据迭代器。使用此Gluon
接口的数据迭代器最简单的方法是什么?
我浏览了文档,找不到一个简单的方法。我的一个想法是将它用作迭代器并从每个批处理中获取数据和标签,如下所示。
for e in range(epochs):
train_iter.reset()
for batch_data in train_iter:
data = nd.concatenate(([d for d in batch_data.data]))
label = nd.concatenate(([l for l in batch_data.label]))
with autograd.record():
output = net(data)
loss = softmax_cross_entropy(output, label)
loss.backward()
trainer.step(batch_size)
print(nd.mean(loss).asscalar())
但这可能不是最佳的,因为我需要连接每批。
实现这一目标的最佳方法是什么?即是否有系统的 为胶子写一个简单的自定义迭代器的方法?
如何在上述情况下添加上下文信息?
答案 0 :(得分:1)
我认为你的方法有效。基本上,您可以从data
获取batch_data.data
,从label
获取batch_data.label
并将其提供给网络。
我不确定为什么你需要连接数据和标签 - 也许这与你的网络定义有关。
如果您需要拆分数据并在多个GPU上进行训练,可以使用gluon.utils.split_and_load函数来执行此操作。