我需要训练保存在.params中的模型。我使用了this post,但它可以进行推理并且不训练。
我的代码如下:
sym, arg_params, aux_params = mx.model.load_checkpoint(
prefix, 0)
# Dropping the loss from model
new_sym = sym.get_children()[0]
net = gluon.nn.SymbolBlock(outputs=new_sym, inputs=mx.sym.var("data"))
net.initialize(mx.init.Normal(0.002), ctx=ctx)
我也尝试过:
net = gluon.nn.SymbolBlock.imports("net-symbol.json"),
["data"],
"net-0000.params"),
ctx=ctx)
然后,我定义损失,优化器等。一切正常,但模型权重不会改变。
此外,当我用单个转换层替换模型时,它也可以工作。因此,我想有些事情会限制我加载网的重量。