使用TensorFlow更快地进行循环

时间:2018-12-23 10:49:44

标签: python tensorflow

我目前正在使用tensorflow测试某些东西,并且建立了一个小神经网络。为了训练它,我在优化器对象上反复运行会话:

error = 0.5 * tf.square(tf.subtract(output_layer, supervised_layer))
optimizer = tf.train.GradientDescentOptimizer(0.05).minimize(error)

session.run(tf.global_variables_initializer())

for i in range(1000):
    session.run([optimizer, error], feed_dict={
        input_layer: [[0, 1], [1, 0], [0, 0], [1, 1]],
        supervised_layer: [[1], [1], [0], [0]]
    })

我觉得从python经常调用session.run()并没有很好地利用tensorflow的工作方式,并且可能无法利用GPU的处理优势。

所以我的问题现在很简单:session.run()上是否有任何开销,并且有一种更好的方式来经常运行会话。

2 个答案:

答案 0 :(得分:1)

在典型情况下,单个session.run最多需要花费几秒钟的时间,因此执行python循环的成本可以忽略不计。因此简短的答案是您不必担心,因为在需要时,您一定会从GPU加速中受益。您已经分批执行操作。我建议简单地增加批量大小。

但是,如果您完全确定需要在一个session.run调用中执行多次更新,则可以使用多种方法在tensorflow(tf.while_loopautograph)中实现循环。因此,您实际上可以在一个session.run中多次调用任意代码。请注意,通常这是相当困难的,但是在大多数实际情况下,您可以找到更简单的解决方案,例如为RNNs

答案 1 :(得分:0)

在一般情况下,使用sess.run()而不使用close()后缀会出现问题。

会话可能拥有资源,例如tf.Variable,因此在不再需要这些资源时释放它们很重要。为此,请在会话上调用tf.Session.close()方法,或将该会话用作context manager

以下两个示例是等效的:

# Using the `close()` method.
sess = tf.Session()
sess.run(...)
sess.close()

# Using the context manager.
with tf.Session() as sess:
sess.run(...)

经常致电session.run() ...

Session类用于运行TensorFlow操作,并且会话对象封装了在其中执行Operation对象并评估Tensor对象的环境。

如果您可以在不创建环境的情况下完成计算,则可以说没有session.run()会更明智。

注意:Session.run()方法运行TensorFlow计算的一个“步骤”,方法是运行必要的图形片段以执行每个Operation并评估Tensor中的每个fetches,然后替换feed_dict中对应输入值的值。

fetches参数可以是单个图形元素,也可以是在其叶子处包含图形元素的任意嵌套的列表,元组,namedtuple,dict或OrderedDict。图元素可以是以下类型之一:

  • 一个tf.Operation。相应的获取值为None
  • 一个tf.Tensor。相应的获取值将是一个numpy ndarray,其中包含   张量的值。
  • 一个tf.SparseTensor。相应的获取值将是tf.SparseTensorValue,其中包含该稀疏张量的值。
  • 一个get_tensor_handle前页。相应的获取值将是   包含该张量的句柄的numpy ndarray。
  • 一个string,它是图中的张量或操作的名称。

更新:交互式会话也无济于事。与普通Session相比,唯一的区别是InteractiveSession使其成为默认会话,因此您可以使用会话对象显式调用变量上的run()eval()

#start the session
sess = tf.InteractiveSession()

# stop the session
sess.stop()
ops.reset_default_graph()