TensorFlow变量名称 - 保存/恢复中的分配错误

时间:2017-04-30 23:16:01

标签: tensorflow

我正在使用tf.train.Saver saverestore保存并恢复TensorFlow模型。在恢复过程中,我正在加载新的输入数据。 restore方法抛出此错误:

  

InvalidArgumentError(参见上面的回溯):Assign需要形状   两个张量匹配。 lhs shape = [1334,3] rhs shape = [1246,3]   [[节点:save / Assign_6 =分配[T = DT_FLOAT,_class = [“loc:@ Variable_2”],   use_locking = true,validate_shape = true,   _device =“/ job:localhost / replica:0 / task:0 / cpu:0”](Variable_2,save / RestoreV2_6)]]

这似乎表明问题出在Variable_2,但是如何确定代码中哪个变量对应Variable_2

2 个答案:

答案 0 :(得分:1)

  • 如果要恢复模型并进行前馈,那么模型形状和模型架构在保存时应该相同
  • 所以上面的错误是说当你恢复你的模型时,其中一个被保存的张量具有形状[1246,3],但是你将它分配给形状为[1334,3]的张量
  • 明确知道哪些变量属于哪个名称,您可以为张量分配唯一的名称,例如a = tf.placeholder("float", [3, 3], name="tensor_a")
  • 所以现在当你恢复模型时,你知道你的模型在图中有一个张量,名字=“tensor_a”,形状为3x3
  • 代码中的快速教程:

    # Create some variables.
    v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer)
    v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer)
    
    inc_v1 = v1.assign(v1+1)
    dec_v2 = v2.assign(v2-1)
    
    # Add an op to initialize the variables.
    init_op = tf.global_variables_initializer()
    
    # Add ops to save and restore all the variables.
    saver = tf.train.Saver()
    
    # Later, launch the model, initialize the variables, do some work, and save the
    # variables to disk.
    with tf.Session() as sess:
        sess.run(init_op)
        # Do some work with the model.
        inc_v1.op.run()
        dec_v2.op.run()
        # Save the variables to disk.
        save_path = saver.save(sess, "/tmp/model.ckpt")
        print("Model saved in file: %s" % save_path)
    
    tf.reset_default_graph()
    
    # Create some variables.
    d1 = tf.get_variable("v1", shape=[3])
    d2 = tf.get_variable("v2", shape=[5])
    
    # Add ops to save and restore all the variables.
    saver = tf.train.Saver()
    
    # Later, launch the model, use the saver to restore variables from disk, and
    # do some work with the model.
    with tf.Session() as sess:
        # Restore variables from disk.
        saver.restore(sess, "/tmp/model.ckpt")
        print("Model restored.")
        # Check the values of the variables
        print("v1 : %s" % d1.eval())
        print("v2 : %s" % d2.eval())
    
  • 如果您注意到上面的代码d1和v1现在具有相同的形状,如果您更改任何变量的形状,它将引发一个错误,类似于您得到的错误

答案 1 :(得分:-1)

当您创建一个新变量时,它将获得一个唯一的名称。 Saver.restore在检查点中查看相同的名称。如果您需要使用不同的名称从不同的检查点初始化某些变量,请查看tf.contrib.framework.init_from_checkpoint