如何还原,更新和保存经过训练的模型

时间:2019-01-08 04:28:56

标签: python tensorflow

在下面的代码中,我恢复了一个经过训练的模型,名为mymodel。并获得名称为'var_x'的张量。然后我通过串联列表[3]来更改'var_x'的值。最后,我保存了新模型。

如果tf.assign将变量名设置为'var_x',则模型newmodel中的值仍与模型mymodel中的值相同,而没有连接列表[3]。

如果tf.assign将变量名设置为'var_y'或其他但不是'var_x',那么在恢复新模型时,我可以获得正确的'var_y'值。

sess= tf.Session() # restore trained model--mymodel
saver = tf.train.import_meta_graph('mymodel.meta')
saver.restore(sess, tf.train.latest_checkpoint(myfileDir, latest_filename = 
    'mymodel-checkpoint'))
graph = tf.get_default_graph()
x = graph.get_tensor_by_name('var_x:0')
saver1 = tf.train.Saver()
x_update=tf.concat([x,[3]]) #change variable with concatenating list
y = tf.Variable(tf.random_normal(shape=[4]), dtype = tf.float32)
tf.assign(y,x_update, name='var_x')
saver1.save(sess, 'newmodel', latest_filename='newmodel-checkpoint')

我不知道如何通过在模型newmodel的'var_x'中串联[3]来保存代码以保存更改的值。

1 个答案:

答案 0 :(得分:0)

tf.assign不会立即执行分配。相反,它返回一个赋值操作,您可以用一个sess.run调用来评估它。完成后,您的变量将具有x_update的值。

您可以像这样直接分配给var_x:0

sess = tf.Session()
saver = tf.train.import_meta_graph('mymodel.meta')
saver.restore(sess, tf.train.latest_checkpoint(myfileDir, latest_filename='mymodel-checkpoint'))
graph = tf.get_default_graph()

x = graph.get_tensor_by_name('var_x:0')
new_value = tf.concat([x, [3]], axis=0)
assign_op = tf.assign(x, new_value, validate_shape=False)
sess.run(assign_op)

saver1 = tf.train.Saver()
saver1.save(sess, 'newmodel', latest_filename='newmodel-checkpoint')

关键思想是x是指var_x:0,它已经是tf.Variable。您需要在分配中使用validate_shape=False,因为您正在更改var_x的形状。