使用MonitoredTrainingSession时如何恢复权重?

时间:2019-05-16 03:22:47

标签: tensorflow

我正在训练CNN网络,并且已经对网络的一部分进行了预训练,所以我希望将训练后的权重恢复到新网络。我的训练代码如下:

 FrameLayout placeHolder = findViewById(R.id.placeHolder);

    ImageView iv = findViewById(R.id.inner);

经过训练的权重将添加为以下内容:

def train():
 with tf.Graph().as_default():
  global_step = tf.train.get_or_create_global_step()

  with tf.device('/cpu:0'):
    signals, labels = cifar10.distorted_inputs()

  logits = cifar10.inference(signals)

  var_list=tf.contrib.framework.get_local_variables()    
  saver = tf.train.Saver(var_list)

  loss = cifar10.loss(logits, labels)

  train_op = cifar10.train(loss, global_step)  

  with tf.train.MonitoredTrainingSession(checkpoint_dir=FLAGS.train_dir,
    hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
           tf.train.NanTensorHook(loss)],
    config=tf.ConfigProto(
        log_device_placement=FLAGS.log_device_placement)) as mon_sess:
    saver.restore(mon_sess, ".../model.ckpt-100000")
    while not mon_sess.should_stop():
      mon_sess.run(train_op)

我的问题是,尽管它可以像普通的tensorflow项目一样运行,但我不确定使用MonitoredTrainingSession编写的代码是否正确。 或者如何调试以确保权重正确还原并且在运行训练代码时不会受到训练。 另一个问题是我不知道如何初始化变量。恢复后是否会初始化训练后的权重?感谢您的答复!

0 个答案:

没有答案