如何使用tf.cond进行批处理

时间:2017-02-10 12:56:18

标签: tensorflow conditional

我想将tf.cond(pred, fn1, fn2, name=None)用于条件分支。假设我有两个张量:x, y。每个张量都是0/1的批量,我想使用这个张量压缩x < y作为源 tf.cond pred参数:

  

pred:标量,用于确定是否返回fn1或fn2的结果。

但是如果我正在使用批处理,那么看起来我需要迭代图中的源张量并批量处理每个项目的切片并对每个项目应用tf.cond。看起来很可疑。为什么tf.cond不接受批处理而只接受标量?您能否告知批量使用它的正确方法是什么?

1 个答案:

答案 0 :(得分:7)

tf.where听起来像你想要的:张量之间的矢量化选择。

tf.cond是一个控制流修饰符:它确定哪些操作已执行,因此很难想到有用的批处理语义。

我们还可以将这些操作混合在一起:根据条件进行切片并将这些切片传递给两个分支的操作​​。

import tensorflow as tf
from tensorflow.python.util import nest

def slicing_where(condition, full_input, true_branch, false_branch):
  """Split `full_input` between `true_branch` and `false_branch` on `condition`.

  Args:
    condition: A boolean Tensor with shape [B_1, ..., B_N].
    full_input: A Tensor or nested tuple of Tensors of any dtype, each with
      shape [B_1, ..., B_N, ...], to be split between `true_branch` and
      `false_branch` based on `condition`.
    true_branch: A function taking a single argument, that argument having the
      same structure and number of batch dimensions as `full_input`. Receives
      slices of `full_input` corresponding to the True entries of
      `condition`. Returns a Tensor or nested tuple of Tensors, each with batch
      dimensions matching its inputs.
    false_branch: Like `true_branch`, but receives inputs corresponding to the
      false elements of `condition`. Returns a Tensor or nested tuple of Tensors
      (with the same structure as the return value of `true_branch`), but with
      batch dimensions matching its inputs.
  Returns:
    Interleaved outputs from `true_branch` and `false_branch`, each Tensor
    having shape [B_1, ..., B_N, ...].
  """
  full_input_flat = nest.flatten(full_input)
  true_indices = tf.where(condition)
  false_indices = tf.where(tf.logical_not(condition))
  true_branch_inputs = nest.pack_sequence_as(
      structure=full_input,
      flat_sequence=[tf.gather_nd(params=input_tensor, indices=true_indices)
                     for input_tensor in full_input_flat])
  false_branch_inputs = nest.pack_sequence_as(
      structure=full_input,
      flat_sequence=[tf.gather_nd(params=input_tensor, indices=false_indices)
                     for input_tensor in full_input_flat])
  true_outputs = true_branch(true_branch_inputs)
  false_outputs = false_branch(false_branch_inputs)
  nest.assert_same_structure(true_outputs, false_outputs)
  def scatter_outputs(true_output, false_output):
    batch_shape = tf.shape(condition)
    scattered_shape = tf.concat(
        [batch_shape, tf.shape(true_output)[tf.rank(batch_shape):]],
        0)
    true_scatter = tf.scatter_nd(
        indices=tf.cast(true_indices, tf.int32),
        updates=true_output,
        shape=scattered_shape)
    false_scatter = tf.scatter_nd(
        indices=tf.cast(false_indices, tf.int32),
        updates=false_output,
        shape=scattered_shape)
    return true_scatter + false_scatter
  result = nest.pack_sequence_as(
      structure=true_outputs,
      flat_sequence=[
          scatter_outputs(true_single_output, false_single_output)
          for true_single_output, false_single_output
          in zip(nest.flatten(true_outputs), nest.flatten(false_outputs))])
  return result

一些例子:

vector_test = slicing_where(
    condition=tf.equal(tf.range(10) % 2, 0),
    full_input=tf.range(10, dtype=tf.float32),
    true_branch=lambda x: 0.2 + x,
    false_branch=lambda x: 0.1 + x)

cross_range = (tf.range(10, dtype=tf.float32)[:, None]
               * tf.range(10, dtype=tf.float32)[None, :])
matrix_test = slicing_where(
    condition=tf.equal(tf.range(10) % 3, 0),
    full_input=cross_range,
    true_branch=lambda x: -x,
    false_branch=lambda x: x + 0.1)

with tf.Session():
  print(vector_test.eval())
  print(matrix_test.eval())

打印:

[ 0.2         1.10000002  2.20000005  3.0999999   4.19999981  5.0999999
  6.19999981  7.0999999   8.19999981  9.10000038]
[[  0.           0.           0.           0.           0.           0.
    0.           0.           0.           0.        ]
 [  0.1          1.10000002   2.0999999    3.0999999    4.0999999
    5.0999999    6.0999999    7.0999999    8.10000038   9.10000038]
 [  0.1          2.0999999    4.0999999    6.0999999    8.10000038
   10.10000038  12.10000038  14.10000038  16.10000038  18.10000038]
 [  0.          -3.          -6.          -9.         -12.         -15.
  -18.         -21.         -24.         -27.        ]
 [  0.1          4.0999999    8.10000038  12.10000038  16.10000038
   20.10000038  24.10000038  28.10000038  32.09999847  36.09999847]
 [  0.1          5.0999999   10.10000038  15.10000038  20.10000038
   25.10000038  30.10000038  35.09999847  40.09999847  45.09999847]
 [  0.          -6.         -12.         -18.         -24.         -30.
  -36.         -42.         -48.         -54.        ]
 [  0.1          7.0999999   14.10000038  21.10000038  28.10000038
   35.09999847  42.09999847  49.09999847  56.09999847  63.09999847]
 [  0.1          8.10000038  16.10000038  24.10000038  32.09999847
   40.09999847  48.09999847  56.09999847  64.09999847  72.09999847]
 [  0.          -9.         -18.         -27.         -36.         -45.
  -54.         -63.         -72.         -81.        ]]