在MXNet中使用高级API时,如何将其他数据输入网络

时间:2018-07-14 11:43:49

标签: deep-learning mxnet

我正在研究MXNet框架,我需要在每次迭代期间将矩阵输入网络。矩阵存储在外部存储器中,不是训练数据,而是在每次迭代结束时由网络输出更新的。在迭代过程中,必须将矩阵输入网络。

如果我使用高级API,即

model = mx.mod.Module(context=ctx, symbol=sym) ... ... model.fit(train_data_iter, begin_epoch=begin_epoch, end_epoch=end_epoch, ......)

这可以实现吗?

1 个答案:

答案 0 :(得分:1)

model.fit()未提供您想要的功能。但是,要实现的目标在Apache MXNet的Gluon API中非常容易实现。使用Gluon API,您可以为训练循环编写7行代码,而不必使用单个model.fit()。这是一个典型的训练循环代码:

for epoch in range(10):
    for data, label in train_data:
        # forward + backward
        with autograd.record():
            output = net(data)
            loss = softmax_cross_entropy(output, label)
        loss.backward()
        trainer.step(batch_size)  # update parameters

因此,如果您想将网络的输出反馈回输入中,则可以轻松实现。要开始使用Gluon,我建议使用60-minute Gluon Crash Course。要成为Gluon的专家,我推荐Deep Learning - The Straight Dope书以及MXNet主要网站上的一整套教程:http://mxnet.apache.org/tutorials/index.html