mxnet:如何使用训练有素的RNN模型进行预测

时间:2017-03-08 12:40:18

标签: python recurrent-neural-network mxnet

我正在测试mxnet的RNN模型。教程here不起作用,错误消息表示许多函数已被弃用。我没有找到RNN的最新教程。 mxnet项目中仍有一些示例。但对于RNN,examples仅显示如何使用训练集训练模型。他们没有展示如何使用训练模型进行进一步预测。培训代码如下:

model.fit(
    train_data          = data_train,
    eval_data           = data_val,
    eval_metric         = mx.metric.Perplexity(invalid_label),
    kvstore             = args.kv_store,
    optimizer           = args.optimizer,
    optimizer_params    = { 'learning_rate': args.lr,
                            'momentum': args.mom,
                            'wd': args.wd },
    initializer         = mx.init.Xavier(factor_type="in", magnitude=2.34),
    num_epoch           = args.num_epochs,
    batch_end_callback  = mx.callback.Speedometer(args.batch_size, args.disp_batches))

有人知道如何使用经过培训的 RNN模型进行推理或预测吗?

我必须澄清我正在寻找如何使用 RNN模型进行预测,而不是CNN或其他模型。

非常感谢你帮助我!!!

1 个答案:

答案 0 :(得分:1)

通常模型是扩展BaseModel类。 BaseModel有the method predict。该方法可以使用fit方法使用的相同类型:DataIter只有一个区别,它不需要train_data,只需要eval_data。因此,实际的预测过程可以通过以下简单方式实现:

result = mod.predict(dataiter.next)