如何在使用MXnet时保存模型

时间:2017-04-20 11:14:27

标签: r deep-learning mxnet

我正在使用MXnet训练CNN(在R中),我可以使用以下代码训练模型而不会出现任何错误:

model <- mx.model.FeedForward.create(symbol=network,
                                     X=train.iter,
                                     ctx=mx.gpu(0),
                                     num.round=20,
                                     array.batch.size=batch.size,
                                     learning.rate=0.1,
                                     momentum=0.1,  
                                     eval.metric=mx.metric.accuracy,
                                     wd=0.001,
                                     batch.end.callback=mx.callback.log.speedometer(batch.size, frequency = 100)
    )

但是由于这个过程非常耗时,我会在夜间在服务器上运行它,我希望在完成培训后保存模型以便使用它。

我用过:

save(list = ls(), file="mymodel.RData")

mx.model.save("mymodel", 10)

但他们都不能保存模型!例如,当我加载"mymodel.RData"时,我无法预测测试集的标签!

另一个例子是当我加载"mymodel.RData"并尝试使用以下代码绘制它时:

graph.viz(model$symbol$as.json())

我收到以下错误:

Error in model$symbol$as.json() : external pointer is not valid

有人可以给我一个保存然后加载此模型的解决方案以供将来使用吗?

由于

3 个答案:

答案 0 :(得分:2)

您可以按

保存模型
model <- mx.model.FeedForward.create(symbol=network,
                                 X=train.iter,
                                 ctx=mx.gpu(0),
                                 num.round=20,
                                 array.batch.size=batch.size,
                                 learning.rate=0.1,
                                 momentum=0.1,  
                                 eval.metric=mx.metric.accuracy,
                                 wd=0.001,
                                 epoch.end.callback=mx.callback.save.checkpoint("model_prefix")
                                 batch.end.callback=mx.callback.log.speedometer(batch.size, frequency = 100)
)

答案 1 :(得分:0)

保存培训进度快照的最佳做法是在每次纪元培训后使用save_snapshot(http://mxnet.io/api/python/module.html#mxnet.module.Module.save_checkpoint)作为回调的一部分。在R中,等效命令可能是mx.callback.save.checkpoint,但我没有使用R并且不确定它的用法。

使用这些快照还可以让您利用AWS Spot市场(https://aws.amazon.com/ec2/spot/pricing/)的低成本选项,例如现在提供和16 K80 GPU的实例,每小时3.8美元,相比之下 - 要价14.4美元。只要您正确使用这些快照,这种80%-90%的折扣在现货市场中很常见,并且可以优化培训的速度和成本。

答案 2 :(得分:0)

mxnet模型是一个R列表,但是它的第一个组件不是R对象而是C ++指针,因此不能保存并重新加载为R对象。因此,需要对模型进行序列化以使其表现为实际的R对象。序列化的对象也是一个列表,但是它的第一个对象是包含模型信息的文本。

要保存模型:

modelR <- mx.serialize(model)
save(modelR, file="~/model1.RData")

要检索并再次使用它:

load("~/model1.RData", verbose=TRUE)
model <- mx.unserialize(modelR)