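How do I add some new variables to a loaded checkpoint in TensorFlow?

Time: 2018-05-21 23:50:43

Tags: python tensorflow save

I have trained a large graph in TensorFlow and saved it to a checkpoint with the following function: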

import os

def save_model(sess, saver, param_folder, saved_ckpt):
    print("Saving model to disk...")
    address = os.path.join(param_folder, 'model')
    if not os.path.isdir(address):
        os.makedirs(address)
    address = os.path.join(address, saved_ckpt)
    # saver.save() writes the checkpoint files and, by default, a .meta graph file
    save_path = saver.save(sess, address)
    # explicit export of the graph definition to the same .meta path
    saver.export_meta_graph(filename=address+'.meta')
    print("Model saved in file: %s" % save_path)

Now, to load the graph, I use the following function:

import os
import tensorflow as tf

def load_model(sess, saver, param_folder, saved_ckpt):
    print("Loading model from disk...")
    address = os.path.join(param_folder, 'model')
    if not os.path.isdir(address):
        os.makedirs(address)
    address = os.path.join(address, saved_ckpt)
    print("meta graph address :", address)
    # import_meta_graph rebuilds the saved graph and returns a new Saver,
    # which replaces the `saver` argument passed in
    saver = tf.train.import_meta_graph(address+'.meta')
    saver.restore(sess, address)

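A nice feature of TensorFlow is that it automatically assigns the saved weights from the checkpoint to the matching graph. The problem arises when I want to load the graph saved in the checkpoint into a graph that is slightly different from, or an extension of, the one I saved. For example, suppose I add an extra neural network to the previous graph and want to load the weights from the previous checkpoint, so that I don't have to train the model from scratch.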

So, in short, my question is: how do I load a previously saved subgraph into a larger (you could say parent) graph?

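1 Answer:

Answer 0: (score: 2)

I ran into this problem too and used @rvinas's comment. I'm posting this just to make things easier for the next reader.

When loading the saved variables, you can add, remove, or edit them in restore_dict, like this: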

import os
import tensorflow as tf  # TensorFlow 1.x API

def load_model(sess, saver, param_folder, saved_ckpt):
    print("Loading model from disk...")
    address = os.path.join(param_folder, 'model')
    if not os.path.isdir(address):
        os.makedirs(address)
    address = os.path.join(address, saved_ckpt)
    print("meta graph address :", address)
    # remove the next two lines
    # saver = tf.train.import_meta_graph(address+'.meta')
    # saver.restore(sess, address)
    # instead put this block:

    reader = tf.train.NewCheckpointReader(address)
    restore_dict = dict()
    for v in tf.trainable_variables():
        # strip the ':0' output suffix to get the name stored in the checkpoint
        tensor_name = v.name.split(':')[0]
        if reader.has_tensor(tensor_name):
            print('has tensor ', tensor_name)
            restore_dict[tensor_name] = v
            # put the logic of the new/modified variable here and assign to the restore_dict, i.e.
            # restore_dict['my_var_scope/my_var'] = get_my_variable()

    # Restore only the collected variables. Variables that are not in
    # restore_dict (e.g. newly added layers) must be initialized separately.
    saver = tf.train.Saver(restore_dict)
    saver.restore(sess, address)
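
The loop above only checks whether a name exists in the checkpoint. A variable that kept its name but changed its shape (for example, a resized layer) would still be picked up, and restoring it would likely fail. Below is a minimal, framework-free sketch of the same matching logic with a shape check added. The function name plan_partial_restore and all variable names and shapes are hypothetical, chosen only for illustration. In TensorFlow 1.x, the checkpoint shapes could presumably be taken from reader.get_variable_to_shape_map() and the graph shapes from v.shape.as_list() for each trainable variable.

```python
def plan_partial_restore(checkpoint_shapes, graph_shapes):
    """Split graph variables into those to restore and those to initialize.

    checkpoint_shapes: dict mapping tensor name -> shape stored in the checkpoint.
    graph_shapes: dict mapping variable name (without ':0') -> shape in the current graph.
    Returns two sorted lists of names: (to_restore, to_initialize).
    """
    to_restore, to_initialize = [], []
    for name, shape in sorted(graph_shapes.items()):
        saved_shape = checkpoint_shapes.get(name)
        if saved_shape is not None and list(saved_shape) == list(shape):
            to_restore.append(name)
        else:
            # Either missing from the checkpoint or saved with a different shape.
            to_initialize.append(name)
    return to_restore, to_initialize


# Hypothetical names and shapes, for illustration only.
checkpoint_shapes = {
    "encoder/w": [128, 64],
    "encoder/b": [64],
    "head/w": [64, 10],
}
graph_shapes = {
    "encoder/w": [128, 64],
    "encoder/b": [64],
    "head/w": [64, 20],       # same name, but the layer was resized
    "extra_net/w": [64, 32],  # newly added network
}

to_restore, to_initialize = plan_partial_restore(checkpoint_shapes, graph_shapes)
print("restore from checkpoint:", to_restore)
print("initialize from scratch:", to_initialize)
```

With a split like this, the names in to_restore would go into restore_dict, and the remaining variables would be initialized on their own (for instance with tf.variables_initializer) before training continues.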

Hope this helps.