在TensorFlow中检测最后一个维度是1还是5?

时间:2017-11-24 15:57:18

标签: tensorflow

我正在编写TensorFlow(python)逻辑来确定张量的最后一个维度是1还是5.如果张量是标量,则该表达式应为false。在图形构造时,张量的形状是未知的。

鉴于张量input,我试过

tf.logical_and(
  # The tensor must not be a scalar.
  tf.greater(tf.rank(input), 0),
  # Check the last dimension.
  tf.logical_or(
    tf.equal(tf.shape(input)[-1], 1),
    tf.equal(tf.shape(input)[-1], 5)
  )
)

但是,当input张量是标量时,此逻辑会引发错误,因为表达式的tf.greater(tf.rank(input), 0)部分无法导致TensorFlow短路(并避免执行tf.logical_or图的一部分)。这是预期的行为。

有没有什么方法可以找到张量的最后一个维度但是逻辑优雅地处理输入张量是标量的情况?

也许比如有一种方法来强制执行控制依赖,导致检查等级首先运行?

我想我可以在这里使用tf.cond,但我对调用lambda函数如何更改图表感到有些不安。

2 个答案:

答案 0 :(得分:0)

您可以简单地扩展输入张量的调整以使其在所有情况下都有效(即将标量大小写转换为张量):

import tensorflow as tf
import numpy as np

input = tf.placeholder(tf.float32)
input_expanded = tf.expand_dims(input, axis=0)
last_dim_size = tf.shape(input_expanded)[-1]

result_tensor = tf.logical_and(
    # The tensor must not be a scalar.
    tf.greater(tf.rank(input_expanded) - 1, 0),
    # Check the last dimension.
    tf.logical_or(
        tf.equal(last_dim_size, 1),
        tf.equal(last_dim_size, 5)
    )
)

with tf.Session() as sess:
    for input_value in [1, np.zeros((2,)), np.zeros((1, 5)), np.zeros((1, 2, 6))]:
        result = sess.run(result_tensor, feed_dict={input: input_value})
        print('Input: {}'.format(input_value))
        print('Output: {}'.format(result))
        print()

输出:

Input: 1
Output: False

Input: [ 0.  0.]
Output: False

Input: [[ 0.  0.  0.  0.  0.]]
Output: True

Input: [[[ 0.  0.  0.  0.  0.  0.]
  [ 0.  0.  0.  0.  0.  0.]]]
Output: False

答案 1 :(得分:0)

在TF2中,您可以简单地使用.shape并检查最后一个([-1])尺寸是否为您要寻找的尺寸。您不必运行会话,因为版本2中默认启用了渴望执行。

some_tensor.shape[-1] in [1, 5]

可乐演示:https://colab.research.google.com/drive/1L4XD04XBuPSBeaB7Bb-twpHJCyYPkZrt