如何在程序中重用简单的密集Tensorflow图?

时间:2018-05-17 05:58:02

标签: python tensorflow

我对Tensorflow很新,并试图找到一种使用简单程序保存和恢复密集层的方法。我使用以下简单程序初始化并保存图形。

from __future__ import absolute_import, division, print_function

import os

import tensorflow as tf

x = tf.constant([[1], [2], [3], [4]], dtype=tf.float32,name = "x" )
y_true = tf.constant([[0], [-1], [-2], [-3]], dtype=tf.float32, name = "y_t")

linear_model = tf.layers.Dense(units=1, name = "sutej")

y_pred = linear_model(x)
loss = tf.losses.mean_squared_error(labels=y_true, predictions=y_pred)

optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)

init = tf.global_variables_initializer()

sess = tf.Session()
sess.run(init)

for i in range(1000):
  _, loss_value = sess.run((train, loss))
  print(loss_value)

all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
print(all_vars)
for v in all_vars:
    v_ = sess.run(v)
    print(v_)

print(sess.run(y_pred))

saver = tf.train.Saver()

saver.save(sess, '/home/sutej/Tensorflow/newsave/newsave',global_step=1000)

以下代码恢复图

from __future__ import absolute_import, division, print_function

import os

import tensorflow as tf

sess = tf.Session()
saver = tf.train.import_meta_graph('/home/sutej/Tensorflow/newsave/newsave-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('/home/sutej/Tensorflow/newsave/'))

graph = tf.get_default_graph()
x=graph.get_tensor_by_name('x:0')
y_true=graph.get_tensor_by_name('y_t:0')

graph = tf.get_default_graph()

all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
print(all_vars)
for v in all_vars:
    v_ = sess.run(v)
    print(v_)

print(sess.run('sutej/kernel:0'))


print(sess.run(tf.layers.dense(inputs=x,units=1,name = 'sutej', reuse=True)))

在输出中,我能够看到图形已经恢复,因此保存文件中的常量,偏差和权重也是如此。但是当我尝试通过密集层(我的恢复代码的最后一行)传递输入时,我收到一个错误。输出日志如下。

[<tf.Variable 'sutej/kernel:0' shape=(1, 1) dtype=float32_ref>, <tf.Variable 'sutej/bias:0' shape=(1,) dtype=float32_ref>]
[[-0.98440635]]
[0.95415276]
[[-0.98440635]]
Traceback (most recent call last):
  File "load_saver.py", line 26, in <module>
    print(sess.run(tf.layers.dense(inputs=x,units=1,name = 'sutej', reuse=True)))
  File "/home/sutej/.local/lib/python3.5/site-packages/tensorflow/python/layers/core.py", line 253, in dense
    return layer.apply(inputs)
  File "/home/sutej/.local/lib/python3.5/site-packages/tensorflow/python/layers/base.py", line 828, in apply
    return self.__call__(inputs, *args, **kwargs)
  File "/home/sutej/.local/lib/python3.5/site-packages/tensorflow/python/layers/base.py", line 699, in __call__
    self.build(input_shapes)
  File "/home/sutej/.local/lib/python3.5/site-packages/tensorflow/python/layers/core.py", line 138, in build
    trainable=True)
  File "/home/sutej/.local/lib/python3.5/site-packages/tensorflow/python/layers/base.py", line 546, in add_variable
    partitioner=partitioner)
  File "/home/sutej/.local/lib/python3.5/site-packages/tensorflow/python/training/checkpointable.py", line 436, in _add_variable_with_custom_getter
    **kwargs_for_getter)
  File "/home/sutej/.local/lib/python3.5/site-packages/tensorflow/python/ops/variable_scope.py", line 1317, in get_variable
    constraint=constraint)
  File "/home/sutej/.local/lib/python3.5/site-packages/tensorflow/python/ops/variable_scope.py", line 1079, in get_variable
    constraint=constraint)
  File "/home/sutej/.local/lib/python3.5/site-packages/tensorflow/python/ops/variable_scope.py", line 425, in get_variable
    constraint=constraint)
  File "/home/sutej/.local/lib/python3.5/site-packages/tensorflow/python/ops/variable_scope.py", line 394, in _true_getter
    use_resource=use_resource, constraint=constraint)
  File "/home/sutej/.local/lib/python3.5/site-packages/tensorflow/python/ops/variable_scope.py", line 751, in _get_single_variable
    "reuse=tf.AUTO_REUSE in VarScope?" % name)
ValueError: Variable sutej/kernel does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=tf.AUTO_REUSE in VarScope?

我错了什么?我该如何解决这个问题?我不想手动繁殖内核并添加偏见,但我在代码的第一部分中寻找更像print(sess.run(y_pred))的更优雅的东西。

感谢您的帮助。

1 个答案:

答案 0 :(得分:0)

您想要的东西 - print(sess.run(y_pred)) - 正是您在恢复后应该做的事情。您导入了图形并恢复了变量。 TensorFlow世界几乎与您在训练脚本中的状态相同。您可以像在训练脚本中运行图形一样运行图形。

提高你的理解力。调用tf.layers.dense创建变量和操作(在图中表示为节点)。因此,您在训练脚本中调用了它。导入图形并恢复变量值时,您创建了变量和操作。再次呼叫tf.layers.dense毫无意义。

当您尝试使用字符串调用sess.run()时,它不起作用,因为您需要传入张量对象。有几种方法可以获得所需的张量:

  • 像你一样使用get_tensor_by_name。这有点脆弱,因为当您处理模型时,张量名称可能会发生变化。
  • 您可以在保存图表之前将张量添加到collection,并在还原时从此集合中检索它们。集合名称应该比张量名称更稳定。
  • 如果您有构建图表的代码(即您的培训脚本)。您可以使用该代码重新创建相同的图形。然后,您将拥有训练期间所拥有的相同python对象y_pred。在这种情况下,您需要对还原脚本进行更改。由于您使用代码创建了图表,因此无需import_meta_graph_def。您只需要恢复变量值。