在下面的代码中,我恢复了一个经过训练的模型,名为mymodel。并获得名称为'var_x'
的张量。然后我通过串联列表[3]来更改'var_x'
的值。最后,我保存了新模型。
如果tf.assign将变量名设置为'var_x'
,则模型newmodel中的值仍与模型mymodel中的值相同,而没有连接列表[3]。
如果tf.assign将变量名设置为'var_y'
或其他但不是'var_x'
,那么在恢复新模型时,我可以获得正确的'var_y'
值。
sess= tf.Session() # restore trained model--mymodel
saver = tf.train.import_meta_graph('mymodel.meta')
saver.restore(sess, tf.train.latest_checkpoint(myfileDir, latest_filename =
'mymodel-checkpoint'))
graph = tf.get_default_graph()
x = graph.get_tensor_by_name('var_x:0')
saver1 = tf.train.Saver()
x_update=tf.concat([x,[3]]) #change variable with concatenating list
y = tf.Variable(tf.random_normal(shape=[4]), dtype = tf.float32)
tf.assign(y,x_update, name='var_x')
saver1.save(sess, 'newmodel', latest_filename='newmodel-checkpoint')
我不知道如何通过在模型newmodel的'var_x'
中串联[3]来保存代码以保存更改的值。
答案 0 :(得分:0)
tf.assign
不会立即执行分配。相反,它返回一个赋值操作,您可以用一个sess.run
调用来评估它。完成后,您的变量将具有x_update
的值。
您可以像这样直接分配给var_x:0
:
sess = tf.Session()
saver = tf.train.import_meta_graph('mymodel.meta')
saver.restore(sess, tf.train.latest_checkpoint(myfileDir, latest_filename='mymodel-checkpoint'))
graph = tf.get_default_graph()
x = graph.get_tensor_by_name('var_x:0')
new_value = tf.concat([x, [3]], axis=0)
assign_op = tf.assign(x, new_value, validate_shape=False)
sess.run(assign_op)
saver1 = tf.train.Saver()
saver1.save(sess, 'newmodel', latest_filename='newmodel-checkpoint')
关键思想是x
是指var_x:0
,它已经是tf.Variable
。您需要在分配中使用validate_shape=False
,因为您正在更改var_x
的形状。