Scaffold and tf.train.MonitoredTrainingSession

Time: 2017-06-08 15:11:14

Tags: tensorflow

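I would like to know how to use a Scaffold with tf.train.MonitoredTrainingSession and initialize the graph weights with specific values imported from NumPy arrays. I could not find any clear example of similar usage. Thanks.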

1 answer:

Answer 0 (score: 2)

There are actually several ways to go about this.

Saving a graph checkpoint

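  • Construct your graph.
  • Initialize all variables.
  • Run a session that assigns the desired values to each variable.
  • Save a checkpoint to be loaded at training time.
  • Use that checkpoint at training time (for example, by passing its directory as checkpoint_dir to tf.train.MonitoredTrainingSession).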
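The five steps above can be sketched end to end with a toy graph. This is a minimal sketch that assumes TF 1.x-style graph code, written against `tensorflow.compat.v1` so it also runs on TF 2.x. The names `build_model`, `write_initial_checkpoint` and `read_w_at_train_time`, the variable `w`, and the file name `model.ckpt` are hypothetical stand-ins for your own model.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()


def build_model():
    # Hypothetical stand-in for your real graph: one variable plus a global step
    tf.train.get_or_create_global_step()
    return tf.get_variable("w", shape=[2, 2], dtype=tf.float32)


def write_initial_checkpoint(values, train_dir):
    with tf.Graph().as_default():
        w = build_model()                                        # 1. construct the graph
        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())          # 2. init all variables
            sess.run(w.assign(np.asarray(values, np.float32)))   # 3. assign the values
            # 4. save; this also writes the "checkpoint" index file in train_dir
            return saver.save(sess, os.path.join(train_dir, "model.ckpt"))


def read_w_at_train_time(train_dir):
    with tf.Graph().as_default():
        w = build_model()  # must match the graph the checkpoint was saved from
        # 5. MonitoredTrainingSession restores the latest checkpoint in checkpoint_dir
        with tf.train.MonitoredTrainingSession(checkpoint_dir=train_dir) as sess:
            return sess.run(w)


# Usage with a temporary directory standing in for the real train_dir
train_dir = tempfile.mkdtemp()
starting_values = np.arange(4, dtype=np.float32).reshape(2, 2)
write_initial_checkpoint(starting_values, train_dir)
restored_w = read_w_at_train_time(train_dir)
```

When `checkpoint_dir` is set, MonitoredTrainingSession also installs its default checkpoint and summary hooks. Expect it to write further checkpoint and event files into the same directory.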

Using model initialization and recovery

See Tensorflow Model Recovery for more details. Basically, you create a tf.train.Scaffold and pass your own initialization function to it as init_fn.
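Here is a minimal sketch of this second approach. It assumes TF 1.x-style graph code via `tensorflow.compat.v1`, and the helper `train_with_numpy_init` and the variable `w` are hypothetical.

Scaffold calls `init_fn(scaffold, session)` after its `init_op` has run. This happens only when no checkpoint was restored, so a restart from a checkpoint keeps the trained values. MonitoredTrainingSession finalizes the graph before `init_fn` runs, so `init_fn` cannot create new ops. The sketch therefore builds the assignment op (fed through a placeholder) up front and only feeds it the NumPy value inside `init_fn`.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()


def train_with_numpy_init(initial_w):
    """Hypothetical helper: build a toy graph and initialize "w" via Scaffold.init_fn."""
    with tf.Graph().as_default():
        w = tf.get_variable("w", shape=initial_w.shape, dtype=tf.float32)
        tf.train.get_or_create_global_step()

        # Build the assignment op up front: the graph is finalized before
        # init_fn runs, so init_fn itself must not create new ops.
        w_value = tf.placeholder(tf.float32, shape=initial_w.shape)
        assign_w = tf.assign(w, w_value)

        def init_fn(scaffold, session):
            # Runs after scaffold.init_op, only when no checkpoint was restored
            session.run(assign_w, feed_dict={w_value: initial_w})

        scaffold = tf.train.Scaffold(init_fn=init_fn)
        # No checkpoint_dir here, so nothing is restored or saved
        with tf.train.MonitoredTrainingSession(scaffold=scaffold) as sess:
            # A real training loop would run its train_op here instead
            return sess.run(w)


# Usage: initialize "w" from a NumPy array and read it back
initial = np.arange(6, dtype=np.float32).reshape(2, 3)
restored = train_with_numpy_init(initial)
```

With many variables, the same pattern extends to one placeholder and one assign op per entry of the loaded NumPy dict. All of them are created before the session, and a single `session.run` in `init_fn` can feed them together.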

I have only tested the first approach, so here is some code for it:

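  import os

  import numpy as np
  import tensorflow as tf  # TensorFlow 1.x API

  with tf.Graph().as_default():

      # Build the graph exactly as it is built for training
      # ... your model-building code ...

      sess = tf.Session()
      with sess.as_default():

          # Add an op to initialize the variables, then run it
          init_op = tf.global_variables_initializer()
          sess.run(init_op)

          # Update the graph with the starting values. model.npy is assumed
          # to hold a pickled dict {variable_name: numpy_array}.
          data_dict = np.load('your_path/model.npy', encoding='latin1',
                              allow_pickle=True).item()
          # Look up existing variables by name instead of creating new ones
          variables = {v.op.name: v for v in tf.global_variables()}
          for param_name, value in data_dict.items():
              sess.run(variables[param_name].assign(value))
          print('assignment done!')

      # Create the Saver after all variables exist
      saver = tf.train.Saver()

      # Save the variables to disk. FLAGS.train_dir is your training
      # directory (defined elsewhere, e.g. with tf.app.flags).
      save_path = saver.save(sess, os.path.join(FLAGS.train_dir, 'model.ckpt'))
      print("Model saved in file: %s" % save_path)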
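The code that loads `model.npy` expects the file to hold a pickled dict mapping variable names to NumPy arrays, which is what `.item()` unwraps. Below is a minimal sketch of writing and reading such a file under that assumption. The helpers `save_param_dict` and `load_param_dict` and the example variable names are hypothetical. Since NumPy 1.16.3, `np.load` refuses pickled object arrays unless `allow_pickle=True` is passed.

```python
import os
import tempfile

import numpy as np


def save_param_dict(params, train_dir, filename="model.npy"):
    # np.save wraps the dict in a 0-d object array and pickles it
    path = os.path.join(train_dir, filename)
    np.save(path, params)
    return path


def load_param_dict(path):
    # .item() unwraps the 0-d object array back into the original dict;
    # NumPy >= 1.16.3 needs allow_pickle=True to load object arrays
    return np.load(path, encoding="latin1", allow_pickle=True).item()


# Usage with hypothetical variable names
example_params = {
    "conv1/weights": np.ones((3, 3), dtype=np.float32),
    "conv1/biases": np.zeros(3, dtype=np.float32),
}
example_path = save_param_dict(example_params, tempfile.mkdtemp())
loaded = load_param_dict(example_path)
```

The dict keys must match the graph's variable names (for example `v.op.name`) for the assignment loop to find them.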