同一个Session.run()调用中的多个顺序Tensorflow操作

时间:2017-12-21 09:59:29

标签: tensorflow

正如标题所示,我想在同一个Session.run()调用中运行多个Tensorflow操作。具体来说,为了使问题更具体,假设我想在一次调用中运行多个训练迭代。

使用多个Session.run()调用执行此操作的标准方法是:

# Declare the function that we want to minimize
func = ...

# Create the optimizer which will perform a single optimization iteration
optimizer = tf.train.AdamOptimizer().minimize(func)

# Run N optimization iterations
N = 10
with tf.Session() as sess:

    sess.run( tf.global_variables_initializer() )
    for i in range(N):
        sess.run( optimizer )

但是,这当然会产生一些开销,因为我们正在进行多个会话调用。我假设我们可以通过某种方式对操作进行分组来消除一些重要的开销。我假设groupcount_up_to是我应该使用的,但是我找不到任何示例来演示如何在这种情况下使用它们。有人可以指点我正确的方向吗?

最终目标是定义一些复合操作,该操作将在一次调用中运行N次迭代,以便上述内容可以转换为如下所示:

# Declare the function that we want to minimize
func = ...

# Create the optimizer which will perform a single optimization iteration
optimizer = tf.train.AdamOptimizer().minimize(func)

# Create the compound operation that will run the optimizer 10 times
optimizeNIterations = ?????
with tf.Session() as sess:

    sess.run( tf.global_variables_initializer() )
    sess.run( optimizeNIterations )

EDIT ::

正如musically_ut指出的那样,我确实可以通过强制问题来提供饲料字典来将操作链接在一起。但这感觉就像解决了一个非常具体的问题。我总体关注的是如何在单个会话运行中顺序执行操作。我可以举出另一个例子,为什么你会想要这个....

现在假设除了想要运行我的优化器之外,我想要检索优化值,让我们说它们位于变量X中。如果我想优化并获得优化值,我可以尝试做这样的事情

with tf.Session() as sess:

    sess.run( tf.global_variables_initializer() )
    o, x = sess.run( [ optimizer, X ] )

但实际上这不起作用,因为操作(优化器,X)不按顺序运行。我基本上需要进行2次会话:

with tf.Session() as sess:

    sess.run( tf.global_variables_initializer() )
    o = sess.run( optimizer )
    x = sess.run( X )

问题是如何将这两个调用合二为一。

1 个答案:

答案 0 :(得分:3)

听起来你可以在tf.while_loop中多次执行你想要运行的操作。如果操作是独立的,则可能必须将parallel_iterations设置为1或(更好)使用控制依赖项来对优化程序调用进行排序。例如:

import tensorflow as tf

with tf.Graph().as_default():
  opt = tf.train.AdamOptimizer(0.1)
  # Use a resource variable for a true "read op"
  var = tf.get_variable(name="var", shape=[], use_resource=True)
  def _cond(i, _):
    return tf.less(i, 20)  # 20 iterations
  def _body(i, sequencer):
    with tf.control_dependencies([sequencer]):
      loss = .5 * (var - 10.) ** 2
      print_op = tf.Print(loss, ["Evaluating loss", i, loss])
    with tf.control_dependencies([print_op]):
      train_op = opt.minimize(loss)
    with tf.control_dependencies([train_op]):
      next_sequencer = tf.ones([])
    return i + 1, next_sequencer
  initial_value = var.read_value()
  with tf.control_dependencies([initial_value]):
    _, sequencer = tf.while_loop(cond=_cond, body=_body, loop_vars=[0, 1.])
  with tf.control_dependencies([sequencer]):
    final_value = var.read_value()
  init_op = tf.global_variables_initializer()
  with tf.Session() as session:
    session.run([init_op])
    print(session.run([initial_value, final_value]))

打印:

2017-12-21 11:40:35.920035: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][0][46.3987083]
2017-12-21 11:40:35.920317: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][1][45.4404]
2017-12-21 11:40:35.920534: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][2][44.4923515]
2017-12-21 11:40:35.920715: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][3][43.55476]
2017-12-21 11:40:35.920905: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][4][42.6277695]
2017-12-21 11:40:35.921084: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][5][41.711544]
2017-12-21 11:40:35.921273: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][6][40.8062363]
2017-12-21 11:40:35.921426: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][7][39.9120026]
2017-12-21 11:40:35.921578: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][8][39.028965]
2017-12-21 11:40:35.921732: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][9][38.1572723]
2017-12-21 11:40:35.921888: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][10][37.2970314]
2017-12-21 11:40:35.922053: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][11][36.4483566]
2017-12-21 11:40:35.922187: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][12][35.6113625]
2017-12-21 11:40:35.922327: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][13][34.7861366]
2017-12-21 11:40:35.922472: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][14][33.9727631]
2017-12-21 11:40:35.922613: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][15][33.1713257]
2017-12-21 11:40:35.922777: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][16][32.3818779]
2017-12-21 11:40:35.922942: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][17][31.6044941]
2017-12-21 11:40:35.923115: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][18][30.8392067]
2017-12-21 11:40:35.923253: I tensorflow/core/kernels/logging_ops.cc:79] [Evaluating loss][19][30.0860634]
[0.36685812, 2.3390481]