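I'm very new to TensorFlow and I'm trying to find a way to save and restore a dense layer using a simple program. I initialize and save the graph with the following simple program: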
from __future__ import absolute_import, division, print_function
import os
import tensorflow as tf
x = tf.constant([[1], [2], [3], [4]], dtype=tf.float32,name = "x" )
y_true = tf.constant([[0], [-1], [-2], [-3]], dtype=tf.float32, name = "y_t")
linear_model = tf.layers.Dense(units=1, name = "sutej")
y_pred = linear_model(x)
loss = tf.losses.mean_squared_error(labels=y_true, predictions=y_pred)
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
for i in range(1000):
    _, loss_value = sess.run((train, loss))
    print(loss_value)
all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
print(all_vars)
for v in all_vars:
    v_ = sess.run(v)
    print(v_)
print(sess.run(y_pred))
saver = tf.train.Saver()
saver.save(sess, '/home/sutej/Tensorflow/newsave/newsave',global_step=1000)
The following code restores the graph:
from __future__ import absolute_import, division, print_function
import os
import tensorflow as tf
sess = tf.Session()
saver = tf.train.import_meta_graph('/home/sutej/Tensorflow/newsave/newsave-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('/home/sutej/Tensorflow/newsave/'))
graph = tf.get_default_graph()
x=graph.get_tensor_by_name('x:0')
y_true=graph.get_tensor_by_name('y_t:0')
graph = tf.get_default_graph()
all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
print(all_vars)
for v in all_vars:
    v_ = sess.run(v)
    print(v_)
print(sess.run('sutej/kernel:0'))
print(sess.run(tf.layers.dense(inputs=x,units=1,name = 'sutej', reuse=True)))
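In the output I can see that the graph has been restored, and so have the constants, the bias and the weights from the save file. But when I try to pass the input through the dense layer (the last line of my restore code), I get an error. The output log is below.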
[<tf.Variable 'sutej/kernel:0' shape=(1, 1) dtype=float32_ref>, <tf.Variable 'sutej/bias:0' shape=(1,) dtype=float32_ref>]
[[-0.98440635]]
[0.95415276]
[[-0.98440635]]
Traceback (most recent call last):
File "load_saver.py", line 26, in <module>
print(sess.run(tf.layers.dense(inputs=x,units=1,name = 'sutej', reuse=True)))
File "/home/sutej/.local/lib/python3.5/site-packages/tensorflow/python/layers/core.py", line 253, in dense
return layer.apply(inputs)
File "/home/sutej/.local/lib/python3.5/site-packages/tensorflow/python/layers/base.py", line 828, in apply
return self.__call__(inputs, *args, **kwargs)
File "/home/sutej/.local/lib/python3.5/site-packages/tensorflow/python/layers/base.py", line 699, in __call__
self.build(input_shapes)
File "/home/sutej/.local/lib/python3.5/site-packages/tensorflow/python/layers/core.py", line 138, in build
trainable=True)
File "/home/sutej/.local/lib/python3.5/site-packages/tensorflow/python/layers/base.py", line 546, in add_variable
partitioner=partitioner)
File "/home/sutej/.local/lib/python3.5/site-packages/tensorflow/python/training/checkpointable.py", line 436, in _add_variable_with_custom_getter
**kwargs_for_getter)
File "/home/sutej/.local/lib/python3.5/site-packages/tensorflow/python/ops/variable_scope.py", line 1317, in get_variable
constraint=constraint)
File "/home/sutej/.local/lib/python3.5/site-packages/tensorflow/python/ops/variable_scope.py", line 1079, in get_variable
constraint=constraint)
File "/home/sutej/.local/lib/python3.5/site-packages/tensorflow/python/ops/variable_scope.py", line 425, in get_variable
constraint=constraint)
File "/home/sutej/.local/lib/python3.5/site-packages/tensorflow/python/ops/variable_scope.py", line 394, in _true_getter
use_resource=use_resource, constraint=constraint)
File "/home/sutej/.local/lib/python3.5/site-packages/tensorflow/python/ops/variable_scope.py", line 751, in _get_single_variable
"reuse=tf.AUTO_REUSE in VarScope?" % name)
ValueError: Variable sutej/kernel does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=tf.AUTO_REUSE in VarScope?
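What am I doing wrong, and how can I fix it? I don't want to multiply by the kernel and add the bias by hand; I'm looking for something more elegant, like print(sess.run(y_pred)) in the first part of my code.
Thanks for your help.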
Answer (score: 0)
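What you want, print(sess.run(y_pred)), is exactly what you should do after restoring. You imported the graph and restored the variables, so the TensorFlow world is in almost the same state it was in at the end of your training script. You can run the graph just as you did in your training script.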
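The point above can be sketched end to end. This is a minimal sketch, assuming the tf.compat.v1 graph-mode API (available in recent TF 1.x releases and in TF 2.x). It uses tf.compat.v1.get_variable inside a "sutej" variable scope as a stand-in for tf.layers.Dense, and the output name "y_pred" is a hypothetical name attached with tf.identity so the restore side can find it. The restore function contains no layer code at all: it imports the meta graph, restores the variables and runs the prediction tensor that already exists.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode (TF 1.x style) API


def build_and_save(ckpt_dir):
    """Stand-in for the training script: build y_pred, initialize, save."""
    with tf.Graph().as_default():
        x = tf.constant([[1.0], [2.0], [3.0], [4.0]], name="x")
        with tf1.variable_scope("sutej"):
            kernel = tf1.get_variable("kernel", shape=[1, 1])
            bias = tf1.get_variable("bias", shape=[1],
                                    initializer=tf1.zeros_initializer())
        # Hypothetical explicit name so the restore script can look the tensor up.
        y_pred = tf.identity(tf.matmul(x, kernel) + bias, name="y_pred")
        with tf1.Session() as sess:
            sess.run(tf1.global_variables_initializer())
            saved_value = sess.run(y_pred)
            tf1.train.Saver().save(sess, os.path.join(ckpt_dir, "newsave"),
                                   global_step=1000)
    return saved_value


def restore_and_run(ckpt_dir):
    """Stand-in for the restore script: import, restore, run; no layer code."""
    with tf.Graph().as_default() as graph:
        with tf1.Session() as sess:
            saver = tf1.train.import_meta_graph(
                os.path.join(ckpt_dir, "newsave-1000.meta"))
            saver.restore(sess, tf1.train.latest_checkpoint(ckpt_dir))
            y_pred = graph.get_tensor_by_name("y_pred:0")
            # sess.run("y_pred:0") would also work: fetches may be tensor names.
            return sess.run(y_pred)


with tempfile.TemporaryDirectory() as ckpt_dir:
    saved = build_and_save(ckpt_dir)
    restored = restore_and_run(ckpt_dir)

print(np.allclose(saved, restored))
```

If the output tensor was never given an explicit name, the name generated by the layer can be found in the restore script by printing [op.name for op in graph.get_operations()].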
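To deepen your understanding: calling tf.layers.dense creates variables and ops (represented as nodes in the graph). That is why you called it in your training script. When you imported the graph and restored the variable values, those variables and ops were created again. Calling tf.layers.dense a second time makes no sense. Moreover, the variables recreated by import_meta_graph are not registered with the variable store that tf.get_variable uses, so reuse=True cannot find sutej/kernel, which is exactly the ValueError in your traceback.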
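A minimal sketch of what import_meta_graph leaves behind, assuming the tf.compat.v1 graph-mode API. It creates sutej/kernel and sutej/bias with tf.compat.v1.get_variable inside a variable_scope (a stand-in for what tf.layers.Dense does), saves them, imports the meta graph into a fresh graph, and checks two things: the variables are present, but asking tf.get_variable to reuse them still raises a ValueError, mirroring the error in the question. The checkpoint prefix "newsave" is only illustrative.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode (TF 1.x style) API


def save_toy_model(ckpt_prefix):
    """Create sutej/kernel and sutej/bias, as the Dense layer did, and save them."""
    with tf.Graph().as_default():
        with tf1.variable_scope("sutej"):
            tf1.get_variable("kernel", shape=[1, 1])
            tf1.get_variable("bias", shape=[1], initializer=tf1.zeros_initializer())
        with tf1.Session() as sess:
            sess.run(tf1.global_variables_initializer())
            tf1.train.Saver().save(sess, ckpt_prefix)


def inspect_imported_graph(ckpt_prefix):
    """Import the meta graph and try to reuse the kernel via tf.get_variable."""
    with tf.Graph().as_default():
        tf1.train.import_meta_graph(ckpt_prefix + ".meta")
        # The imported variables are in the GLOBAL_VARIABLES collection ...
        var_names = [v.name for v in tf1.global_variables()]
        # ... but not in get_variable's store, so reuse=True cannot find them.
        try:
            with tf1.variable_scope("sutej", reuse=True):
                tf1.get_variable("kernel", shape=[1, 1])
            reuse_failed = False
        except ValueError:
            reuse_failed = True
    return var_names, reuse_failed


with tempfile.TemporaryDirectory() as d:
    prefix = os.path.join(d, "newsave")
    save_toy_model(prefix)
    var_names, reuse_failed = inspect_imported_graph(prefix)

print(var_names)
print(reuse_failed)
```

The takeaway is the same as the text above: after importing, fetch the existing tensors instead of calling the layer-building code again.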
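Calling sess.run(y_pred) directly does not work in the restore script, because y_pred is only a Python variable in your training script; you need to get hold of the corresponding tensor in the imported graph. (sess.run() also accepts a tensor name string, which is why sess.run('sutej/kernel:0') worked for you.) There are several ways to get the tensor you need: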
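- get_tensor_by_name. This is somewhat brittle, because tensor names can change as you work on your model.
- Add the tensors you need to a collection, and retrieve them from that collection when restoring. Collection names should be more stable than tensor names.
- Re-run the code that created y_pred. In this case you need to change your restore script: since you create the graph with code, there is no need for import_meta_graph. You only need to restore the variable values.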
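The second and third options can be sketched as follows. This is a minimal sketch, assuming the tf.compat.v1 graph-mode API; build() is a hypothetical helper that uses tf.compat.v1.get_variable as a stand-in for tf.layers.Dense, and "predictions" is a hypothetical collection name. restore_via_collection imports the meta graph and pulls y_pred out of the collection; restore_via_rebuild skips import_meta_graph, runs the same model-building code once, and lets a Saver load only the variable values.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode (TF 1.x style) API


def build(x):
    """Hypothetical model-building code shared by training and the rebuild path."""
    with tf1.variable_scope("sutej"):
        kernel = tf1.get_variable("kernel", shape=[1, 1])
        bias = tf1.get_variable("bias", shape=[1], initializer=tf1.zeros_initializer())
    return tf.matmul(x, kernel) + bias


def save(ckpt_prefix):
    with tf.Graph().as_default():
        x = tf.constant([[1.0], [2.0], [3.0], [4.0]], name="x")
        y_pred = build(x)
        # Collections of tensors are exported with the meta graph.
        tf1.add_to_collection("predictions", y_pred)
        with tf1.Session() as sess:
            sess.run(tf1.global_variables_initializer())
            expected = sess.run(y_pred)
            tf1.train.Saver().save(sess, ckpt_prefix)
    return expected


def restore_via_collection(ckpt_prefix):
    with tf.Graph().as_default():
        with tf1.Session() as sess:
            saver = tf1.train.import_meta_graph(ckpt_prefix + ".meta")
            saver.restore(sess, ckpt_prefix)
            y_pred = tf1.get_collection("predictions")[0]
            return sess.run(y_pred)


def restore_via_rebuild(ckpt_prefix):
    with tf.Graph().as_default():
        x = tf.constant([[1.0], [2.0], [3.0], [4.0]], name="x")
        y_pred = build(x)  # called once; no import_meta_graph needed
        with tf1.Session() as sess:
            # Only variable values are read; no initializer is run.
            tf1.train.Saver().restore(sess, ckpt_prefix)
            return sess.run(y_pred)


with tempfile.TemporaryDirectory() as d:
    prefix = os.path.join(d, "model")
    expected = save(prefix)
    from_collection = restore_via_collection(prefix)
    from_rebuild = restore_via_rebuild(prefix)

print(np.allclose(expected, from_collection), np.allclose(expected, from_rebuild))
```

The rebuild path works because the rebuilt variables get the same names (sutej/kernel, sutej/bias) as the ones stored in the checkpoint, which is what Saver uses to match values to variables.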