MXnet:加载检查站和火车

时间:2019-07-17 15:45:39

标签: mxnet

我需要训练保存在.params中的模型。我使用了this post,但它可以进行推理并且不训练。

我的代码如下:

sym, arg_params, aux_params = mx.model.load_checkpoint(
        prefix, 0)
    # Dropping the loss from model
    new_sym = sym.get_children()[0]
    net = gluon.nn.SymbolBlock(outputs=new_sym, inputs=mx.sym.var("data"))
    net.initialize(mx.init.Normal(0.002), ctx=ctx)

我也尝试过:

   net = gluon.nn.SymbolBlock.imports("net-symbol.json"),
            ["data"],
            "net-0000.params"),
            ctx=ctx)

然后,我定义损失,优化器等。一切正常,但模型权重不会改变。

此外,当我用单个转换层替换模型时,它也可以工作。因此,我想有些事情会限制我加载网的重量。

0 个答案:

没有答案