我正在使用MXnet训练CNN(在R中),我可以使用以下代码训练模型而不会出现任何错误:
model <- mx.model.FeedForward.create(symbol=network,
X=train.iter,
ctx=mx.gpu(0),
num.round=20,
array.batch.size=batch.size,
learning.rate=0.1,
momentum=0.1,
eval.metric=mx.metric.accuracy,
wd=0.001,
batch.end.callback=mx.callback.log.speedometer(batch.size, frequency = 100)
)
但是由于这个过程非常耗时,我会在夜间在服务器上运行它,我希望在完成培训后保存模型以便使用它。
我用过:
save(list = ls(), file="mymodel.RData")
和
mx.model.save("mymodel", 10)
但他们都不能保存模型!例如,当我加载"mymodel.RData"
时,我无法预测测试集的标签!
另一个例子是当我加载"mymodel.RData"
并尝试使用以下代码绘制它时:
graph.viz(model$symbol$as.json())
我收到以下错误:
Error in model$symbol$as.json() : external pointer is not valid
有人可以给我一个保存然后加载此模型的解决方案以供将来使用吗?
由于
答案 0 :(得分:2)
您可以按
保存模型model <- mx.model.FeedForward.create(symbol=network,
X=train.iter,
ctx=mx.gpu(0),
num.round=20,
array.batch.size=batch.size,
learning.rate=0.1,
momentum=0.1,
eval.metric=mx.metric.accuracy,
wd=0.001,
epoch.end.callback=mx.callback.save.checkpoint("model_prefix")
batch.end.callback=mx.callback.log.speedometer(batch.size, frequency = 100)
)
答案 1 :(得分:0)
保存培训进度快照的最佳做法是在每次纪元培训后使用save_snapshot(http://mxnet.io/api/python/module.html#mxnet.module.Module.save_checkpoint)作为回调的一部分。在R中,等效命令可能是mx.callback.save.checkpoint,但我没有使用R并且不确定它的用法。
使用这些快照还可以让您利用AWS Spot市场(https://aws.amazon.com/ec2/spot/pricing/)的低成本选项,例如现在提供和16 K80 GPU的实例,每小时3.8美元,相比之下 - 要价14.4美元。只要您正确使用这些快照,这种80%-90%的折扣在现货市场中很常见,并且可以优化培训的速度和成本。
答案 2 :(得分:0)
mxnet模型是一个R列表,但是它的第一个组件不是R对象而是C ++指针,因此不能保存并重新加载为R对象。因此,需要对模型进行序列化以使其表现为实际的R对象。序列化的对象也是一个列表,但是它的第一个对象是包含模型信息的文本。
要保存模型:
modelR <- mx.serialize(model)
save(modelR, file="~/model1.RData")
要检索并再次使用它:
load("~/model1.RData", verbose=TRUE)
model <- mx.unserialize(modelR)