MXnet:继续训练从文件加载的模型时出错

时间:2017-09-25 14:13:49

标签: python deep-learning mxnet

我也在repository发布了这个问题。

我正在尝试按照教程here加载以前保存在文件中的模型。我使用完全相同的命令,如教程中所示,但我遇到以下错误消息:

Traceback (most recent call last):
  File "test.py", line 153, in <module>
    num_epoch=num_epoch)
  File "/home/mypath/software/try_mxnet2/mxnet/python/mxnet/module/base_module.py", line 496, in fit
    self.update_metric(eval_metric, data_batch.label)
  File "/home/mypath/software/try_mxnet2/mxnet/python/mxnet/module/module.py", line 735, in update_metric
    self._exec_group.update_metric(eval_metric, labels)
  File "/home/mypath/software/try_mxnet2/mxnet/python/mxnet/module/executor_group.py", line 567, in update_metric
    for label, axis in zip(labels, self.label_layouts):
TypeError: zip argument #2 must support iteration

加载和重新训练文件的代码如下:

sym, arg_params, aux_params = mx.model.load_checkpoint('../model/test_mymodel', 25)
lenet_model = mx.mod.Module(symbol=sym, context=mx.gpu(), label_names=None)

lenet_model.bind(for_training=True, data_shapes=[('data', (batch_size,3,16,16))], 
         label_shapes=lenet_model._label_shapes)
lenet_model.set_params(arg_params, aux_params, allow_missing=True)
lenet_model.fit(train_iter,
                optimizer='adam',
                optimizer_params={'learning_rate':0.001,'wd':0.0005},
                eval_metric='acc',
                batch_end_callback = mx.callback.Speedometer(batch_size, n_report), 
                epoch_end_callback  = mx.callback.do_checkpoint("../model/test_mymodel", 5),
                num_epoch=num_epoch)

正如我测试的那样,当我注释掉lenet_model.fit(...)行时,没有报告错误。似乎加载的模型不能连续训练,或者我的代码有问题。

我期待着善意的解决方案。谢谢!

1 个答案:

答案 0 :(得分:1)

repo https://github.com/apache/incubator-mxnet/issues/8023中的响应澄清了这是代码中的形状不匹配错误。