Mxnet Gluon自定义数据迭代器

时间:2017-11-13 20:02:24

标签: mxnet

我使用mx.io.DataIter类编写了一个自定义数据迭代器。使用此Gluon接口的数据迭代器最简单的方法是什么?

我浏览了文档,找不到一个简单的方法。我的一个想法是将它用作迭代器并从每个批处理中获取数据和标签,如下所示。

for e in range(epochs):
    train_iter.reset()
    for batch_data in train_iter:
        data = nd.concatenate(([d for d in batch_data.data]))
        label = nd.concatenate(([l for l in batch_data.label]))
        with autograd.record():
            output = net(data)
            loss = softmax_cross_entropy(output, label)
        loss.backward()
        trainer.step(batch_size)
        print(nd.mean(loss).asscalar())

但这可能不是最佳的,因为我需要连接每批。

  1. 实现这一目标的最佳方法是什么?即是否有系统的 为胶子写一个简单的自定义迭代器的方法?

  2. 如何在上述情况下添加上下文信息?

1 个答案:

答案 0 :(得分:1)

我认为你的方法有效。基本上,您可以从data获取batch_data.data,从label获取batch_data.label并将其提供给网络。

我不确定为什么你需要连接数据和标签 - 也许这与你的网络定义有关。

如果您需要拆分数据并在多个GPU上进行训练,可以使用gluon.utils.split_and_load函数来执行此操作。